Rumah >Peranti teknologi >AI >Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

王林
王林ke hadapan
2023-04-11 22:31:131003semak imbas

Memandangkan model pembelajaran mesin terus meningkat dalam kerumitan dan keupayaan. Teknik yang berkesan untuk meningkatkan prestasi model besar dan kompleks pada set data kecil ialah penyulingan pengetahuan, yang melibatkan latihan model yang lebih kecil dan lebih cekap untuk meniru tingkah laku model "guru" yang lebih besar.

Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

Dalam artikel ini, kita akan meneroka konsep penyulingan pengetahuan dan cara melaksanakannya dalam PyTorch. Kita akan melihat bagaimana ia boleh digunakan untuk memampatkan model yang besar dan sukar digunakan kepada model yang lebih kecil dan lebih cekap sambil mengekalkan ketepatan dan prestasi model asal.

Kami mula-mula mentakrifkan masalah yang perlu diselesaikan dengan penyulingan pengetahuan.

Kami melatih rangkaian saraf dalam yang besar untuk melaksanakan tugas yang kompleks seperti klasifikasi imej atau terjemahan mesin. Model ini mungkin mempunyai beribu-ribu lapisan dan berjuta-juta parameter, menjadikannya sukar untuk digunakan dalam aplikasi dunia sebenar, peranti tepi, dsb. Dan model yang sangat besar ini juga memerlukan banyak sumber pengkomputeran untuk dijalankan, yang menjadikannya tidak dapat berfungsi pada beberapa platform yang dikekang oleh sumber.

Salah satu cara untuk menyelesaikan masalah ini ialah menggunakan penyulingan pengetahuan untuk memampatkan model besar kepada model yang lebih kecil. Proses ini melibatkan latihan model yang lebih kecil untuk meniru tingkah laku model yang lebih besar dalam tugasan yang diberikan.

Kami akan melakukan contoh penyulingan pengetahuan menggunakan set data x-ray dada daripada Kaggle untuk klasifikasi radang paru-paru. Set data yang kami gunakan disusun ke dalam 3 folder (kereta api, ujian, val) dan mengandungi subfolder untuk setiap kategori imej (Pneumonia/Normal). Terdapat 5,863 imej x-ray (JPEG) dan 2 kategori (pneumonia/normal).

Bandingkan gambar kedua-dua kelas ini:

Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

Pemuatan dan prapemprosesan data tidak ada kena mengena sama ada kami menggunakan penyulingan pengetahuan atau model, kod tertentu coretan mungkin Seperti yang ditunjukkan di bawah:

transforms_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 transforms_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 train_data = ImageFolder(root=train_dir, transform=transforms_train)
 test_data = ImageFolder(root=test_dir, transform=transforms_test)
 
 train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

Model Guru

Dalam model guru latar belakang ini, kami menggunakan Resnet-18 dan diperhalusi pada set data ini.

import torch
 import torch.nn as nn
 import torchvision
 
 class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=True)
for params in self.model.parameters():
params.requires_grad_ = False
 
n_filters = self.model.fc.in_features
self.model.fc = nn.Linear(n_filters, 2)
 
def forward(self, x):
x = self.model(x)
return x

Kod untuk latihan penalaan halus adalah seperti berikut

 def train(model, train_loader, test_loader, optimizer, criterion, device):
dataloaders = {'train': train_loader, 'val': test_loader}
 
for epoch in range(30):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
 
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
 
running_loss = 0.0
running_corrects = 0
 
for inputs, labels in tqdm.tqdm(dataloaders[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
 
optimizer.zero_grad()
 
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
 
_, preds = torch.max(outputs, 1)
 
if phase == 'train':
loss.backward()
optimizer.step()
 
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
 
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
 
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

Ini adalah langkah latihan penalaan halus standard Selepas latihan, kita dapat melihat bahawa model itu mencapai 91%. ketepatan pada set ujian Inilah sebabnya kami tidak memilih model yang lebih besar, kerana ketepatan ujian 91 cukup untuk digunakan sebagai model asas.

Kami tahu bahawa model ini mempunyai 11.7 juta parameter, jadi model itu mungkin tidak semestinya dapat menyesuaikan diri dengan peranti tepi atau senario khusus lain.

Model Pelajar

Pelajar kami ialah CNN yang lebih cetek dengan hanya beberapa lapisan dan kira-kira 100k parameter.

class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(4 * 112 * 112, 2)
 
def forward(self, x):
out = self.layer1(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out

Ia sangat mudah hanya dengan melihat kod, bukan.

Jika saya boleh melatih rangkaian saraf yang lebih kecil ini, mengapa saya perlu bersusah payah dengan penyulingan pengetahuan, kami akan melampirkan pada akhir hasil latihan rangkaian ini dari awal melalui pelarasan hiperparameter dan cara lain untuk perbandingan.

Tetapi kini kami meneruskan langkah penyulingan pengetahuan kami

Latihan penyulingan pengetahuan

Langkah asas latihan tetap sama, tetapi perbezaannya ialah cara mengira kehilangan latihan akhir , Kami akan menggunakan kehilangan model guru, kehilangan model pelajar dan kehilangan penyulingan bersama-sama untuk mengira kerugian akhir.

class DistillationLoss:
def __init__(self):
self.student_loss = nn.CrossEntropyLoss()
self.distillation_loss = nn.KLDivLoss()
self.temperature = 1
self.alpha = 0.25
 
def __call__(self, student_logits, student_target_loss, teacher_logits):
distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
F.softmax(teacher_logits / self.temperature, dim=1))
 
loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
return loss

Fungsi kerugian ialah jumlah wajaran dua perkara berikut:

  • Kehilangan klasifikasi, dipanggil_sasaran_kerugian pelajar
  • Kehilangan penyulingan, pasangan pelajar dan pasangan guru Kehilangan entropi silang antara nombor

Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

Ringkasnya, model guru kami perlu mengajar pelajar bagaimana untuk "berfikir", yang bermaksud bahawa ia tidak Deterministik sebagai contoh, jika kebarangkalian keluaran akhir model guru ialah [0.53, 0.47], kami berharap pelajar juga akan mendapat keputusan yang sama yang sama, dan perbezaan antara ramalan ini ialah kehilangan penyulingan.

Untuk mengawal kehilangan, terdapat dua parameter utama:

  • Berat kehilangan penyulingan: 0 bermakna kami hanya mempertimbangkan kehilangan penyulingan dan sebaliknya.
  • Suhu: Mengukur ketidakpastian ramalan guru.

Dalam perkara di atas, nilai alfa dan suhu adalah berdasarkan hasil terbaik yang telah kami cuba beberapa kombinasi.

Perbandingan keputusan

Ini ialah ringkasan jadual bagi eksperimen ini.

Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch

Kita dapat melihat dengan jelas manfaat besar yang diperoleh daripada menggunakan CNN yang lebih kecil (99.14%), lebih cetek: ketepatan yang lebih baik berbanding latihan tanpa penyulingan 10 mata, dan 11 kali lebih cepat daripada Resnet-18! Iaitu, model kecil kami benar-benar mempelajari sesuatu yang berguna daripada model besar.


Atas ialah kandungan terperinci Contoh kod untuk penyulingan pengetahuan menggunakan PyTorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:51cto.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam