Rumah > Artikel > Peranti teknologi > Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch
Memandangkan model pembelajaran mesin terus meningkat dalam kerumitan dan keupayaan. Teknik yang berkesan untuk meningkatkan prestasi model besar dan kompleks pada set data kecil ialah penyulingan pengetahuan, yang melibatkan latihan model yang lebih kecil dan lebih cekap untuk meniru tingkah laku model "guru" yang lebih besar.
Dalam artikel ini, kita akan meneroka konsep penyulingan pengetahuan dan cara melaksanakannya dalam PyTorch. Kita akan melihat bagaimana ia boleh digunakan untuk memampatkan model yang besar dan sukar digunakan kepada model yang lebih kecil dan lebih cekap sambil mengekalkan ketepatan dan prestasi model asal.
Kami mula-mula mentakrifkan masalah yang perlu diselesaikan dengan penyulingan pengetahuan.
Kami melatih rangkaian saraf dalam yang besar untuk melaksanakan tugas yang kompleks seperti klasifikasi imej atau terjemahan mesin. Model ini mungkin mempunyai beribu-ribu lapisan dan berjuta-juta parameter, menjadikannya sukar untuk digunakan dalam aplikasi dunia sebenar, peranti tepi, dsb. Dan model yang sangat besar ini juga memerlukan banyak sumber pengkomputeran untuk dijalankan, yang menjadikannya tidak dapat berfungsi pada beberapa platform yang dikekang oleh sumber.
Salah satu cara untuk menyelesaikan masalah ini ialah menggunakan penyulingan pengetahuan untuk memampatkan model besar kepada model yang lebih kecil. Proses ini melibatkan latihan model yang lebih kecil untuk meniru tingkah laku model yang lebih besar dalam tugasan yang diberikan.
Kami akan melakukan contoh penyulingan pengetahuan menggunakan set data x-ray dada daripada Kaggle untuk klasifikasi radang paru-paru. Set data yang kami gunakan disusun ke dalam 3 folder (kereta api, ujian, val) dan mengandungi subfolder untuk setiap kategori imej (Pneumonia/Normal). Terdapat 5,863 imej x-ray (JPEG) dan 2 kategori (pneumonia/normal).
Bandingkan gambar kedua-dua kelas ini:
Pemuatan dan prapemprosesan data tidak ada kena mengena sama ada kami menggunakan penyulingan pengetahuan atau model, kod tertentu coretan mungkin Seperti yang ditunjukkan di bawah:
transforms_train = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) transforms_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) train_data = ImageFolder(root=train_dir, transform=transforms_train) test_data = ImageFolder(root=test_dir, transform=transforms_test) train_loader = DataLoader(train_data, batch_size=32, shuffle=True) test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
Dalam model guru latar belakang ini, kami menggunakan Resnet-18 dan diperhalusi pada set data ini.
import torch import torch.nn as nn import torchvision class TeacherNet(nn.Module): def __init__(self): super().__init__() self.model = torchvision.models.resnet18(pretrained=True) for params in self.model.parameters(): params.requires_grad_ = False n_filters = self.model.fc.in_features self.model.fc = nn.Linear(n_filters, 2) def forward(self, x): x = self.model(x) return x
Kod untuk latihan penalaan halus adalah seperti berikut
def train(model, train_loader, test_loader, optimizer, criterion, device): dataloaders = {'train': train_loader, 'val': test_loader} for epoch in range(30): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() running_loss = 0.0 running_corrects = 0 for inputs, labels in tqdm.tqdm(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders[phase].dataset) epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
Ini adalah langkah latihan penalaan halus standard Selepas latihan, kita dapat melihat bahawa model itu mencapai 91%. ketepatan pada set ujian Inilah sebabnya kami tidak memilih model yang lebih besar, kerana ketepatan ujian 91 cukup untuk digunakan sebagai model asas.
Kami tahu bahawa model ini mempunyai 11.7 juta parameter, jadi model itu mungkin tidak semestinya dapat menyesuaikan diri dengan peranti tepi atau senario khusus lain.
Pelajar kami ialah CNN yang lebih cetek dengan hanya beberapa lapisan dan kira-kira 100k parameter.
class StudentNet(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) ) self.fc = nn.Linear(4 * 112 * 112, 2) def forward(self, x): out = self.layer1(x) out = out.view(out.size(0), -1) out = self.fc(out) return out
Ia sangat mudah hanya dengan melihat kod, bukan.
Jika saya boleh melatih rangkaian saraf yang lebih kecil ini, mengapa saya perlu bersusah payah dengan penyulingan pengetahuan, kami akan melampirkan pada akhir hasil latihan rangkaian ini dari awal melalui pelarasan hiperparameter dan cara lain untuk perbandingan.
Tetapi kini kami meneruskan langkah penyulingan pengetahuan kami
Langkah asas latihan tetap sama, tetapi perbezaannya ialah cara mengira kehilangan latihan akhir , Kami akan menggunakan kehilangan model guru, kehilangan model pelajar dan kehilangan penyulingan bersama-sama untuk mengira kerugian akhir.
class DistillationLoss: def __init__(self): self.student_loss = nn.CrossEntropyLoss() self.distillation_loss = nn.KLDivLoss() self.temperature = 1 self.alpha = 0.25 def __call__(self, student_logits, student_target_loss, teacher_logits): distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1), F.softmax(teacher_logits / self.temperature, dim=1)) loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss return loss
Fungsi kerugian ialah jumlah wajaran dua perkara berikut:
Ringkasnya, model guru kami perlu mengajar pelajar bagaimana untuk "berfikir", yang bermaksud bahawa ia tidak Deterministik sebagai contoh, jika kebarangkalian keluaran akhir model guru ialah [0.53, 0.47], kami berharap pelajar juga akan mendapat keputusan yang sama yang sama, dan perbezaan antara ramalan ini ialah kehilangan penyulingan.
Untuk mengawal kehilangan, terdapat dua parameter utama:
Dalam perkara di atas, nilai alfa dan suhu adalah berdasarkan hasil terbaik yang telah kami cuba beberapa kombinasi.
Ini ialah ringkasan jadual bagi eksperimen ini.
Kita dapat melihat dengan jelas manfaat besar yang diperoleh daripada menggunakan CNN yang lebih kecil (99.14%), lebih cetek: ketepatan yang lebih baik berbanding latihan tanpa penyulingan 10 mata, dan 11 kali lebih cepat daripada Resnet-18! Iaitu, model kecil kami benar-benar mempelajari sesuatu yang berguna daripada model besar.
Atas ialah kandungan terperinci Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!