Heim  >  Artikel  >  Technologie-Peripheriegeräte  >  Codebeispiel für die Wissensdestillation mit PyTorch

Codebeispiel für die Wissensdestillation mit PyTorch

王林
王林nach vorne
2023-04-11 22:31:13953Durchsuche

Da Modelle für maschinelles Lernen immer komplexer und leistungsfähiger werden. Eine wirksame Technik zur Verbesserung der Leistung großer, komplexer Modelle bei kleinen Datensätzen ist die Wissensdestillation, bei der ein kleineres, effizienteres Modell trainiert wird, um das Verhalten eines größeren „Lehrer“-Modells nachzuahmen.

Codebeispiel für die Wissensdestillation mit PyTorch

In diesem Artikel untersuchen wir das Konzept der Wissensdestillation und wie man es in PyTorch implementiert. Wir werden sehen, wie es verwendet werden kann, um ein großes, unhandliches Modell in ein kleineres, effizienteres Modell zu komprimieren und dennoch die Genauigkeit und Leistung des Originalmodells beizubehalten.

Wir definieren zunächst das zu lösende Problem durch Wissensdestillation.

Wir haben ein großes tiefes neuronales Netzwerk trainiert, um komplexe Aufgaben wie Bildklassifizierung oder maschinelle Übersetzung auszuführen. Dieses Modell verfügt möglicherweise über Tausende von Schichten und Millionen von Parametern, was die Bereitstellung in realen Anwendungen, Edge-Geräten usw. erschwert. Darüber hinaus erfordert die Ausführung dieses sehr großen Modells auch viele Rechenressourcen, sodass es auf einigen Plattformen mit eingeschränkten Ressourcen nicht funktionieren kann.

Eine Möglichkeit, dieses Problem zu lösen, besteht darin, mithilfe der Wissensdestillation große Modelle in kleinere Modelle zu komprimieren. Bei diesem Prozess wird ein kleineres Modell trainiert, um das Verhalten des größeren Modells bei einer bestimmten Aufgabe nachzuahmen.

Wir werden ein Beispiel für die Wissensdestillation anhand des Brust-Röntgendatensatzes von Kaggle zur Klassifizierung von Lungenentzündungen verwenden. Der von uns verwendete Datensatz ist in drei Ordner (Train, Test, Val) organisiert und enthält Unterordner für jede Bildkategorie (Pneumonie/Normal). Es gibt 5.863 Röntgenbilder (JPEG) und 2 Kategorien (Lungenentzündung/normal).

Vergleichen Sie die Bilder dieser beiden Klassen:

Codebeispiel für die Wissensdestillation mit PyTorch

Das Laden und Vorverarbeiten von Daten ist unabhängig davon, ob wir die Wissensdestillation oder ein bestimmtes Modell verwenden. Der Codeausschnitt könnte so aussehen:

transforms_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 transforms_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 train_data = ImageFolder(root=train_dir, transform=transforms_train)
 test_data = ImageFolder(root=test_dir, transform=transforms_test)
 
 train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

Lehrermodell

Hier Kontext Für das Mittellehrermodell verwenden wir Resnet-18 und optimieren es anhand dieses Datensatzes.

import torch
 import torch.nn as nn
 import torchvision
 
 class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=True)
for params in self.model.parameters():
params.requires_grad_ = False
 
n_filters = self.model.fc.in_features
self.model.fc = nn.Linear(n_filters, 2)
 
def forward(self, x):
x = self.model(x)
return x

Der Code für das Feinabstimmungstraining lautet wie folgt:

 def train(model, train_loader, test_loader, optimizer, criterion, device):
dataloaders = {'train': train_loader, 'val': test_loader}
 
for epoch in range(30):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
 
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
 
running_loss = 0.0
running_corrects = 0
 
for inputs, labels in tqdm.tqdm(dataloaders[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
 
optimizer.zero_grad()
 
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
 
_, preds = torch.max(outputs, 1)
 
if phase == 'train':
loss.backward()
optimizer.step()
 
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
 
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
 
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

Dies ist ein Standard-Feinabstimmungstrainingsschritt. Nach dem Training können wir sehen, dass das Modell im Testsatz eine Genauigkeit von 91 % erreicht hat, weshalb wir dies nicht getan haben Wählen Sie ein größeres Modell, da die Genauigkeit von Test 91 ausreicht, um als Basisklassenmodell verwendet zu werden.

Wir wissen, dass das Modell über 11,7 Millionen Parameter verfügt und daher möglicherweise nicht unbedingt in der Lage ist, sich an Edge-Geräte oder andere spezifische Szenarien anzupassen.

Studentenmodell

Unser Student ist ein flacheres CNN mit nur wenigen Schichten und etwa 100.000 Parametern.

class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(4 * 112 * 112, 2)
 
def forward(self, x):
out = self.layer1(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out

Es ist ganz einfach, wenn man sich den Code ansieht, richtig.

Wenn ich dieses kleinere neuronale Netzwerk einfach trainieren kann, warum sollte ich mich dann mit der Wissensdestillation beschäftigen? Wir werden schließlich die Ergebnisse des Trainings dieses Netzwerks von Grund auf durch Hyperparameteranpassung und andere Vergleichsmethoden anhängen.

Aber jetzt fahren wir mit unseren Wissensdestillationsschritten fort

Wissensdestillationstraining

Die grundlegenden Schritte des Trainings sind die gleichen, aber der Unterschied besteht darin, wie der endgültige Trainingsverlust berechnet wird. Wir verwenden den Verlust des Lehrermodells und das Schülermodell Verlust und Der Destillationsverlust wird zusammen mit dem Endverlust berechnet.

class DistillationLoss:
def __init__(self):
self.student_loss = nn.CrossEntropyLoss()
self.distillation_loss = nn.KLDivLoss()
self.temperature = 1
self.alpha = 0.25
 
def __call__(self, student_logits, student_target_loss, teacher_logits):
distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
F.softmax(teacher_logits / self.temperature, dim=1))
 
loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
return loss

Die Verlustfunktion ist die gewichtete Summe der folgenden zwei Dinge:

  • Klassifizierungsverlust, genannt student_target_loss
  • Destillationsverlust, Kreuzentropieverlust zwischen Schülerlogarithmus und Lehrerlogarithmus

Codebeispiel für die Wissensdestillation mit PyTorch

Einfach ausgedrückt, unser Lehrermodell Wenn die endgültige Ausgabewahrscheinlichkeit des Lehrermodells beispielsweise [0,53, 0,47] beträgt, hoffen wir, dass die Schüler auch die gleichen ähnlichen Ergebnisse erhalten. Der Unterschied zwischen Diese Vorhersagen sind der Destillationsverlust.

Um den Verlust zu kontrollieren, gibt es zwei Hauptparameter:

  • Das Gewicht des Destillationsverlusts: 0 bedeutet, dass wir nur den Destillationsverlust berücksichtigen und umgekehrt.
  • Temperatur: Misst die Unsicherheit von Lehrervorhersagen.

In den oben genannten Punkten basieren die Werte von Alpha und Temperatur auf den besten Ergebnissen, die wir bei einigen Kombinationen ausprobiert haben.

Ergebnisvergleich

Dies ist eine tabellarische Zusammenfassung dieses Experiments.

Codebeispiel für die Wissensdestillation mit PyTorch

Wir können deutlich die enormen Vorteile erkennen, die durch die Verwendung eines kleineren (99,14 %) und flacheren CNN erzielt werden: 10 Punkte Genauigkeitsverbesserung im Vergleich zum Training ohne Destillation und 11 Punkte schneller als Resnet-18 Times! Das kleine Modell hat wirklich etwas Nützliches vom großen Modell gelernt.


Das obige ist der detaillierte Inhalt vonCodebeispiel für die Wissensdestillation mit PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen