Code examples for knowledge distillation using PyTorch
As machine learning models continue to grow in complexity and capability, they also become harder to deploy. Knowledge distillation is an effective technique for addressing this: it trains a smaller, more efficient model to mimic the behavior of a larger "teacher" model.
In this article, we explore the concept of knowledge distillation and how to implement it in PyTorch. We will see how it can compress a large, unwieldy model into a smaller, more efficient one while retaining much of the original model's accuracy.
We first define the problem to be solved by knowledge distillation.
Suppose we have trained a large deep neural network to perform a complex task such as image classification or machine translation. Such a model may have many layers and millions of parameters, which makes it difficult to deploy in real-world applications, on edge devices, and so on. A model this large also needs substantial computing resources to run, so it may not work at all on resource-constrained platforms.
One way to solve this problem is to use knowledge distillation to compress the large model into a smaller one. This process trains a smaller model to mimic the behavior of the larger model on a given task.
As an example, we apply knowledge distillation to pneumonia classification using the chest X-ray dataset from Kaggle. The dataset is organized into 3 folders (train, test, val), each containing one subfolder per image category (Pneumonia/Normal). In total there are 5,863 X-ray images (JPEG) in 2 categories (pneumonia/normal).
Compare the pictures of these two classes:
Data loading and preprocessing do not depend on whether we use knowledge distillation or on the specific model. The code might look like this:
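```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# train_dir and test_dir are the paths to the dataset's train and test folders

# Training: resize, random horizontal flip for augmentation,
# then normalize with the ImageNet mean/std expected by the pretrained teacher
transforms_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Testing: the same preprocessing without augmentation
transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

train_data = ImageFolder(root=train_dir, transform=transforms_train)
test_data = ImageFolder(root=test_dir, transform=transforms_test)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)  # no need to shuffle for evaluation
```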
For the teacher model, we use ResNet-18 pretrained on ImageNet and fine-tune it on this dataset.
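```python
import torch
import torch.nn as nn
import torchvision


class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ResNet-18 with ImageNet weights (the weights API needs torchvision >= 0.13;
        # older versions use pretrained=True)
        self.model = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        # Freeze the pretrained backbone
        for params in self.model.parameters():
            params.requires_grad = False
        # Replace the 1,000-class ImageNet head with a trainable 2-class head
        n_filters = self.model.fc.in_features
        self.model.fc = nn.Linear(n_filters, 2)

    def forward(self, x):
        x = self.model(x)
        return x
```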
The fine-tuning code is as follows:
```python
import torch
import tqdm


def train(model, train_loader, test_loader, optimizer, criterion, device, num_epochs=30):
    dataloaders = {'train': train_loader, 'val': test_loader}
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            # train() enables BatchNorm/dropout updates; eval() fixes them
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Accumulate the summed loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
```
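The loop prints validation accuracy every epoch. To report a single test accuracy once training is done, a small helper like the hypothetical `evaluate` function below could be used. It is exercised here on a hand-set linear model and four synthetic points rather than the real dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, loader, device):
    """Return the classification accuracy of `model` on `loader`."""
    model.eval()  # use BatchNorm running statistics, disable dropout
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total


# Toy model: predicts class 1 when x > 0 and class 0 otherwise
model = nn.Linear(1, 2)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[-1.0], [1.0]]))
    model.bias.zero_()

x = torch.tensor([[-2.0], [-1.0], [1.0], [2.0]])
y = torch.tensor([0, 0, 1, 0])  # the last label disagrees with the model
loader = DataLoader(TensorDataset(x, y), batch_size=2)
accuracy = evaluate(model, loader, torch.device("cpu"))
print(accuracy)  # 0.75
```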
This is a standard fine-tuning loop. After training, the model reaches 91% accuracy on the test set. That is also why we did not choose a larger model: 91% test accuracy is good enough for a base model.
ResNet-18 has about 11.7 million parameters, so it may not be suitable for edge devices or other resource-constrained scenarios.
Our student is a shallower CNN with only a few layers and about 100k parameters.
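```python
import torch.nn as nn


class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        # A single conv block: 3 -> 4 channels; max pooling halves 224x224 to 112x112
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Linear classifier over the flattened 4 x 112 x 112 feature map
        self.fc = nn.Linear(4 * 112 * 112, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1)  # flatten to (batch, 4 * 112 * 112)
        out = self.fc(out)
        return out
```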
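Counting the student's parameters confirms the "about 100k" figure and shows where they live. The sketch repeats the `StudentNet` definition so it runs on its own, then runs a random batch through it.

```python
import torch
import torch.nn as nn


class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(4 * 112 * 112, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1)
        return self.fc(out)


student = StudentNet()
n_params = sum(p.numel() for p in student.parameters())
print(f"{n_params:,}")  # 100,474

student.eval()
with torch.no_grad():
    logits = student(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 2])
```

The convolution has 112 parameters and the BatchNorm layer 8, so almost all of the 100,474 sit in the final linear layer (100,354). Because that layer's size is tied to the 112×112 feature map, the student only accepts 224×224 inputs.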
As the code shows, the student is very simple.
Why bother with knowledge distillation if we could simply train this smaller network on its own? For comparison, at the end we include the results of training this network from scratch, with hyperparameter tuning and other techniques.
For now, let's continue with the knowledge distillation steps.
The basic training steps are unchanged; the difference lies in how the final training loss is computed. We combine the student's own loss on the true labels with a distillation loss that compares the student's outputs with the teacher's.
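```python
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss:
    def __init__(self):
        # Hard-label loss of the student against the ground truth
        self.student_loss = nn.CrossEntropyLoss()
        # "batchmean" averages the KL divergence per sample, matching its mathematical definition
        self.distillation_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 1
        self.alpha = 0.25

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        # student_target_loss: the cross-entropy already computed with self.student_loss
        # KLDivLoss expects log-probabilities as input and probabilities as target
        distillation_loss = self.distillation_loss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1))
        # Weighted sum of the student's own loss and the distillation loss
        loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
        return loss
```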
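Put together, one student update could look like the following minimal sketch. `distillation_step` is a hypothetical helper that inlines the same weighted loss (using `F.kl_div` with `reduction="batchmean"`), and small linear layers stand in for `TeacherNet` and `StudentNet` so the example runs instantly. The key points are that the teacher is only queried under `torch.no_grad()` and that only the student's parameters are passed to the optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_step(student, teacher, inputs, labels, optimizer,
                      alpha=0.25, temperature=1.0):
    """One optimization step for the student; the teacher is never updated."""
    teacher.eval()
    student.train()
    with torch.no_grad():  # no gradients flow into the teacher
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    # Hard-label loss against the ground truth
    student_target_loss = F.cross_entropy(student_logits, labels)
    # Soft-label loss: KL(teacher || student) on temperature-softened distributions
    distillation_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    loss = (1 - alpha) * student_target_loss + alpha * distillation_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
teacher = nn.Linear(8, 2)  # stand-in for the fine-tuned TeacherNet
student = nn.Linear(8, 2)  # stand-in for StudentNet
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
inputs = torch.randn(16, 8)
labels = torch.randint(0, 2, (16,))

teacher_before = teacher.weight.clone()
student_before = student.weight.clone()
loss = distillation_step(student, teacher, inputs, labels, optimizer)
print(f"loss: {loss:.4f}")
```

In the real setup, the inner part of the fine-tuning loop would call a step like this for every batch of `train_loader`, with the fine-tuned teacher loaded and kept fixed.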
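The loss function is a weighted sum of two terms: the student's cross-entropy loss on the true labels, and the distillation loss, i.e. the KL divergence between the teacher's and the student's temperature-softened output distributions.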
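Written as a formula for a single sample, matching the `DistillationLoss` code, with $z_s$ and $z_t$ the student and teacher logits, $y$ the true label, $T$ the temperature and $\alpha$ the distillation weight:

$$
\mathcal{L} = (1-\alpha)\,\mathcal{L}_{\mathrm{CE}}\big(y, \operatorname{softmax}(z_s)\big) + \alpha\,\mathrm{KL}\Big(\operatorname{softmax}(z_t/T)\,\Big\|\,\operatorname{softmax}(z_s/T)\Big),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_i p_i \log\frac{p_i}{q_i}
$$

Here $\mathcal{L}_{\mathrm{CE}}\big(y, \operatorname{softmax}(z_s)\big) = -\log \operatorname{softmax}(z_s)_y$, and $p$ is the teacher's distribution while $q$ is the student's.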
Simply put, the teacher should teach the student how to "think", which includes conveying its uncertainty. For example, if the teacher's output probabilities are [0.53, 0.47], we want the student to produce similar probabilities; the difference between these two predictions is the distillation loss.
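To make this concrete, the sketch below computes the distillation term for the teacher output [0.53, 0.47] against a few hypothetical student outputs. It uses `F.kl_div`, which takes the student's log-probabilities as input and the teacher's probabilities as target.

```python
import torch
import torch.nn.functional as F

teacher_probs = torch.tensor([[0.53, 0.47]])  # the uncertain teacher prediction from the text


def distillation_kl(student_probs, teacher_probs):
    # Computes KL(teacher || student) = sum(p_t * (log p_t - log p_s)) per sample
    return F.kl_div(student_probs.log(), teacher_probs, reduction="batchmean").item()


results = {}
for student in ((0.53, 0.47), (0.60, 0.40), (0.90, 0.10)):
    results[student] = distillation_kl(torch.tensor([student]), teacher_probs)
    print(student, round(results[student], 4))
```

A student that copies the teacher's uncertainty gets a loss of (essentially) zero, a slightly more confident one about 0.01, and a student that is confidently "normal" at [0.90, 0.10] about 0.45, even though all three predict the same class.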
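Two main hyperparameters control the loss: alpha, which weights the distillation loss against the student's own loss (0.25 in the code), and temperature, which softens both output distributions before they are compared (1 in the code).
The values of alpha and temperature above are the best-performing combination among those we tried.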
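The effect of the temperature can be seen by dividing a set of logits by T before the softmax. In this sketch the logits [3.0, 1.0] are hypothetical teacher outputs; as T grows, the distribution flattens (the largest probability drops from about 0.88 at T = 1 to about 0.60 at T = 5), exposing more of the teacher's relative preferences to the student. With T = 1, as in the code above, the plain softmax is used.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([3.0, 1.0])  # hypothetical teacher logits for (normal, pneumonia)

max_probs = []
for temperature in (1.0, 2.0, 5.0):
    probs = F.softmax(logits / temperature, dim=0)
    max_probs.append(probs.max().item())
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")
```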
This is a tabular summary of this experiment.
We can clearly see the large benefit of distillation for a 99.14% smaller, shallower CNN: its accuracy improves by 10 points compared with training it without distillation, and it runs 11 times faster than ResNet-18. In other words, our small model really did learn something useful from the large model.
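The 99.14% figure appears to be the reduction in parameter count. Assuming it compares the student's 100,474 parameters (counted from the `StudentNet` definition) with the 11,689,512 parameters of the standard ImageNet ResNet-18 (the "11.7 million" quoted earlier), the arithmetic checks out:

```python
# Assumed parameter counts: StudentNet as defined above, and torchvision's
# ResNet-18 with its original 1,000-class ImageNet head
student_params = 100_474
teacher_params = 11_689_512

reduction = 1 - student_params / teacher_params
print(f"{reduction:.2%} fewer parameters")  # 99.14% fewer parameters
```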