Heim >Technologie-Peripheriegeräte >KI >Codebeispiel für die Wissensdestillation mit PyTorch
Da Modelle für maschinelles Lernen immer komplexer und leistungsfähiger werden. Eine wirksame Technik zur Verbesserung der Leistung großer, komplexer Modelle bei kleinen Datensätzen ist die Wissensdestillation, bei der ein kleineres, effizienteres Modell trainiert wird, um das Verhalten eines größeren „Lehrer“-Modells nachzuahmen.
In diesem Artikel untersuchen wir das Konzept der Wissensdestillation und wie man es in PyTorch implementiert. Wir werden sehen, wie es verwendet werden kann, um ein großes, unhandliches Modell in ein kleineres, effizienteres Modell zu komprimieren und dennoch die Genauigkeit und Leistung des Originalmodells beizubehalten.
Wir definieren zunächst das zu lösende Problem durch Wissensdestillation.
Wir haben ein großes tiefes neuronales Netzwerk trainiert, um komplexe Aufgaben wie Bildklassifizierung oder maschinelle Übersetzung auszuführen. Dieses Modell verfügt möglicherweise über Tausende von Schichten und Millionen von Parametern, was die Bereitstellung in realen Anwendungen, Edge-Geräten usw. erschwert. Darüber hinaus erfordert die Ausführung dieses sehr großen Modells auch viele Rechenressourcen, sodass es auf einigen Plattformen mit eingeschränkten Ressourcen nicht funktionieren kann.
Eine Möglichkeit, dieses Problem zu lösen, besteht darin, mithilfe der Wissensdestillation große Modelle in kleinere Modelle zu komprimieren. Bei diesem Prozess wird ein kleineres Modell trainiert, um das Verhalten des größeren Modells bei einer bestimmten Aufgabe nachzuahmen.
Wir werden ein Beispiel für die Wissensdestillation anhand des Brust-Röntgendatensatzes von Kaggle zur Klassifizierung von Lungenentzündungen verwenden. Der von uns verwendete Datensatz ist in drei Ordner (Train, Test, Val) organisiert und enthält Unterordner für jede Bildkategorie (Pneumonie/Normal). Es gibt 5.863 Röntgenbilder (JPEG) und 2 Kategorien (Lungenentzündung/normal).
Vergleichen Sie die Bilder dieser beiden Klassen:
Das Laden und Vorverarbeiten von Daten ist unabhängig davon, ob wir die Wissensdestillation oder ein bestimmtes Modell verwenden. Der Codeausschnitt könnte so aussehen:
transforms_train = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) train_data = ImageFolder(root=train_dir, transform=transforms_train) test_data = ImageFolder(root=test_dir, transform=transforms_test) train_loader = DataLoader(train_data, batch_size=32, shuffle=True) test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
Hier Kontext Für das Mittellehrermodell verwenden wir Resnet-18 und optimieren es anhand dieses Datensatzes.
import torch import torch.nn as nn import torchvision class TeacherNet(nn.Module): def __init__(self): super().__init__() self.model = torchvision.models.resnet18(pretrained=True) for params in self.model.parameters(): params.requires_grad_ = False n_filters = self.model.fc.in_features self.model.fc = nn.Linear(n_filters, 2) def forward(self, x): x = self.model(x) return x
Der Code für das Feinabstimmungstraining lautet wie folgt:
def train(model, train_loader, test_loader, optimizer, criterion, device): dataloaders = {'train': train_loader, 'val': test_loader} for epoch in range(30): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 running_corrects = 0 for inputs, labels in tqdm.tqdm(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders[phase].dataset) epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
Dies ist ein Standard-Feinabstimmungstrainingsschritt. Nach dem Training können wir sehen, dass das Modell im Testsatz eine Genauigkeit von 91 % erreicht hat, weshalb wir dies nicht getan haben Wählen Sie ein größeres Modell, da die Genauigkeit von Test 91 ausreicht, um als Basisklassenmodell verwendet zu werden.
Wir wissen, dass das Modell über 11,7 Millionen Parameter verfügt und daher möglicherweise nicht unbedingt in der Lage ist, sich an Edge-Geräte oder andere spezifische Szenarien anzupassen.
Unser Student ist ein flacheres CNN mit nur wenigen Schichten und etwa 100.000 Parametern.
class StudentNet(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) ) self.fc = nn.Linear(4 * 112 * 112, 2) def forward(self, x): out = self.layer1(x) out = out.view(out.size(0), -1) out = self.fc(out) return out
Es ist ganz einfach, wenn man sich den Code ansieht, richtig.
Wenn ich dieses kleinere neuronale Netzwerk einfach trainieren kann, warum sollte ich mich dann mit der Wissensdestillation beschäftigen? Wir werden schließlich die Ergebnisse des Trainings dieses Netzwerks von Grund auf durch Hyperparameteranpassung und andere Vergleichsmethoden anhängen.
Aber jetzt fahren wir mit unseren Wissensdestillationsschritten fort
Die grundlegenden Schritte des Trainings sind die gleichen, aber der Unterschied besteht darin, wie der endgültige Trainingsverlust berechnet wird. Wir verwenden den Verlust des Lehrermodells und das Schülermodell Verlust und Der Destillationsverlust wird zusammen mit dem Endverlust berechnet.
class DistillationLoss: def __init__(self): self.student_loss = nn.CrossEntropyLoss() self.distillation_loss = nn.KLDivLoss() self.temperature = 1 self.alpha = 0.25 def __call__(self, student_logits, student_target_loss, teacher_logits): distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1), F.softmax(teacher_logits / self.temperature, dim=1)) loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss return loss
Die Verlustfunktion ist die gewichtete Summe der folgenden zwei Dinge:
Einfach ausgedrückt, unser Lehrermodell Wenn die endgültige Ausgabewahrscheinlichkeit des Lehrermodells beispielsweise [0,53, 0,47] beträgt, hoffen wir, dass die Schüler auch die gleichen ähnlichen Ergebnisse erhalten. Der Unterschied zwischen Diese Vorhersagen sind der Destillationsverlust.
Um den Verlust zu kontrollieren, gibt es zwei Hauptparameter:
In den oben genannten Punkten basieren die Werte von Alpha und Temperatur auf den besten Ergebnissen, die wir bei einigen Kombinationen ausprobiert haben.
Dies ist eine tabellarische Zusammenfassung dieses Experiments.
Wir können deutlich die enormen Vorteile erkennen, die durch die Verwendung eines kleineren (99,14 %) und flacheren CNN erzielt werden: 10 Punkte Genauigkeitsverbesserung im Vergleich zum Training ohne Destillation und 11 Punkte schneller als Resnet-18 Times! Das kleine Modell hat wirklich etwas Nützliches vom großen Modell gelernt.
Das obige ist der detaillierte Inhalt vonCodebeispiel für die Wissensdestillation mit PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!