>  기사  >  기술 주변기기  >  PyTorch를 사용한 지식 추출을 위한 코드 예제

PyTorch를 사용한 지식 추출을 위한 코드 예제

王林
王林앞으로
2023-04-11 22:31:13926검색

머신러닝 모델의 복잡성과 기능이 계속해서 증가함에 따라. 작은 데이터 세트에서 크고 복잡한 모델의 성능을 향상시키는 효과적인 기술은 더 큰 "교사" 모델의 동작을 모방하기 위해 더 작고 효율적인 모델을 훈련시키는 지식 증류입니다.

PyTorch를 사용한 지식 추출을 위한 코드 예제

이 글에서는 지식 증류의 개념과 이를 PyTorch에서 구현하는 방법을 살펴보겠습니다. 크고 다루기 힘든 모델을 더 작고 더 효율적인 모델로 압축하면서도 원래 모델의 정확성과 성능을 유지하는 데 이 방법을 사용할 수 있는 방법을 살펴보겠습니다.

먼저 지식 증류를 통해 해결해야 할 문제를 정의합니다.

우리는 이미지 분류나 기계 번역과 같은 복잡한 작업을 수행하기 위해 대규모 심층 신경망을 훈련했습니다. 이 모델에는 수천 개의 레이어와 수백만 개의 매개변수가 있을 수 있으므로 실제 애플리케이션, 에지 장치 등에 배포하기가 어렵습니다. 그리고 이 매우 큰 모델을 실행하려면 많은 컴퓨팅 리소스가 필요하므로 리소스가 제한된 일부 플랫폼에서는 작동할 수 없습니다.

이 문제를 해결하는 한 가지 방법은 지식 증류를 사용하여 대형 모델을 작은 모델로 압축하는 것입니다. 이 프로세스에는 주어진 작업에서 더 큰 모델의 동작을 모방하기 위해 더 작은 모델을 훈련시키는 것이 포함됩니다.

폐렴 분류를 위해 Kaggle의 흉부 엑스레이 데이터 세트를 사용한 지식 증류의 예를 사용하겠습니다. 우리가 사용한 데이터 세트는 3개의 폴더(train, test, val)로 구성되어 있으며 각 이미지 카테고리(Pneumonia/Normal)에 대한 하위 폴더를 포함합니다. 5,863개의 엑스레이 이미지(JPEG)와 2개의 카테고리(폐렴/정상)가 있습니다.

이 두 클래스의 그림을 비교해 보세요.

PyTorch를 사용한 지식 추출을 위한 코드 예제

데이터 로드 및 전처리는 지식 증류 또는 특정 모델을 사용하는지 여부와 무관하며 코드 조각은 다음과 같습니다.

transforms_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 transforms_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 train_data = ImageFolder(root=train_dir, transform=transforms_train)
 test_data = ImageFolder(root=test_dir, transform=transforms_test)
 
 train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

Teacher Model

여기서 context 중간 교사 모델의 경우 Resnet-18을 사용하고 이 데이터 세트에서 미세 조정합니다.

import torch
 import torch.nn as nn
 import torchvision
 
 class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=True)
for params in self.model.parameters():
params.requires_grad_ = False
 
n_filters = self.model.fc.in_features
self.model.fc = nn.Linear(n_filters, 2)
 
def forward(self, x):
x = self.model(x)
return x

미세 조정 훈련 코드는 다음과 같습니다

 def train(model, train_loader, test_loader, optimizer, criterion, device):
dataloaders = {'train': train_loader, 'val': test_loader}
 
for epoch in range(30):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
 
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
 
running_loss = 0.0
running_corrects = 0
 
for inputs, labels in tqdm.tqdm(dataloaders[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
 
optimizer.zero_grad()
 
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
 
_, preds = torch.max(outputs, 1)
 
if phase == 'train':
loss.backward()
optimizer.step()
 
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
 
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
 
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

이것은 표준 미세 조정 훈련 단계입니다. 훈련 후에는 모델이 테스트 세트에서 91%의 정확도를 달성한 것을 볼 수 있습니다. 더 큰 모델을 선택하는 이유는 테스트 91의 정확도가 기본 클래스 모델로 사용하기에 충분하기 때문입니다.

우리는 모델에 1,170만 개의 매개변수가 있다는 것을 알고 있으므로 엣지 장치나 기타 특정 시나리오에 반드시 적응하지 못할 수도 있습니다.

학생 모델

저희 학생은 몇 개의 레이어와 약 100,000개의 매개변수만 있는 얕은 CNN입니다.

class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(4 * 112 * 112, 2)
 
def forward(self, x):
out = self.layer1(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out

코드를 보면 아주 간단하죠.

이 작은 신경망을 단순히 훈련할 수 있다면 왜 지식 증류에 신경을 써야 할까요? 마침내 하이퍼파라미터 조정 및 기타 비교 수단을 통해 이 네트워크를 처음부터 훈련한 결과를 첨부하겠습니다.

이제 우리는 지식 증류 단계를 계속합니다

지식 증류 훈련

훈련의 기본 단계는 동일하지만 차이점은 최종 훈련 손실을 계산하는 방법입니다. 교사 모델 손실, 학생 모델을 사용하겠습니다. 손실 및 증류 손실은 최종 손실과 함께 계산됩니다.

class DistillationLoss:
def __init__(self):
self.student_loss = nn.CrossEntropyLoss()
self.distillation_loss = nn.KLDivLoss()
self.temperature = 1
self.alpha = 0.25
 
def __call__(self, student_logits, student_target_loss, teacher_logits):
distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
F.softmax(teacher_logits / self.temperature, dim=1))
 
loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
return loss

손실 함수는 다음 두 가지의 가중 합입니다.

  • student_target_loss라고 불리는 분류 손실
  • 증류 손실, 학생 로그와 교사 로그 간의 교차 엔트로피 손실

PyTorch를 사용한 지식 추출을 위한 코드 예제

간단히 말하면 교사 모델 학생들에게 불확실성을 나타내는 "사고" 방법을 가르쳐야 합니다. 예를 들어 교사 모델의 최종 출력 확률이 [0.53, 0.47]인 경우 학생들도 동일한 유사한 결과를 얻을 수 있기를 바랍니다. 이러한 예측은 증류 손실입니다.

손실을 제어하기 위해 두 가지 주요 매개변수가 있습니다.

  • 증류 손실의 가중치: 0은 증류 손실만 고려한다는 의미이며 그 반대의 경우도 마찬가지입니다.
  • 온도: 교사 예측의 불확실성을 측정합니다.

위의 점에서 알파와 온도의 값은 몇 가지 조합을 시도한 최상의 결과를 기반으로 합니다.

결과 비교

이것은 이 실험을 표 형식으로 요약한 것입니다.

PyTorch를 사용한 지식 추출을 위한 코드 예제

더 작고(99.14%) 더 얕은 CNN을 사용하면 얻을 수 있는 엄청난 이점을 명확하게 확인할 수 있습니다. 즉, 증류를 사용하지 않은 훈련에 비해 정확도가 10포인트 향상되고 Resnet-18 Times보다 11포인트 빠릅니다! 작은 모델은 큰 모델로부터 정말 유용한 것을 배웠습니다.


위 내용은 PyTorch를 사용한 지식 추출을 위한 코드 예제의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제