>  기사  >  기술 주변기기  >  PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

WBOY
WBOY앞으로
2023-04-09 10:51:051318검색

최근에는 딥러닝 기반 모델이 객체 감지, 이미지 인식 등의 작업에서 좋은 성능을 발휘했습니다. 1,000개의 서로 다른 개체 분류가 포함된 ImageNet과 같은 까다로운 이미지 분류 데이터 세트에서 일부 모델은 이제 인간 수준을 초과합니다. 그러나 이러한 모델은 감독된 훈련 프로세스에 의존하고, 레이블이 지정된 훈련 데이터의 가용성에 크게 영향을 받으며, 모델이 감지할 수 있는 클래스는 훈련받은 클래스로 제한됩니다.

훈련 중에 모든 클래스에 대해 레이블이 지정된 이미지가 충분하지 않기 때문에 이러한 모델은 실제 환경에서는 덜 유용할 수 있습니다. 그리고 우리는 모델이 훈련 중에 보지 못한 클래스를 인식할 수 있기를 원합니다. 왜냐하면 모든 잠재적 객체의 이미지를 훈련하는 것은 거의 불가능하기 때문입니다. 몇 가지 샘플을 통해 학습할 문제를 Few-Shot 학습이라고 합니다.

퓨샷 학습이란 무엇인가요?

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

퓨샷 학습은 머신러닝의 하위 분야입니다. 여기에는 몇 가지 훈련 샘플과 감독 데이터만으로 새로운 데이터를 분류하는 작업이 포함됩니다. 우리가 만든 모델은 적은 수의 훈련 샘플만으로도 상당히 잘 작동합니다.

다음 시나리오를 고려해보세요. 의료 분야에서는 일부 흔하지 않은 질병의 경우 훈련을 위한 X선 이미지가 충분하지 않을 수 있습니다. 이러한 시나리오의 경우 몇 번의 학습 분류기를 구축하는 것이 완벽한 솔루션입니다.

작은 표본의 변형

일반적으로 연구자들은 네 가지 유형을 식별했습니다.

  1. N-Shot Learning(NSL)
  2. Few-Shot Learning(FSL)
  3. One-Shot Learning(OSL)
  4. Zero-Shot 학습(ZSL)

FSL에 대해 이야기할 때 일반적으로 N-way-K-Shot 분류를 언급합니다. N은 클래스 수를 나타내고, K는 각 클래스에서 훈련할 샘플 수를 나타냅니다. 그래서 N-Shot Learning은 다른 모든 개념보다 더 넓은 개념으로 간주됩니다. Few-Shot, One-Shot 및 Zero-Shot은 NSL의 하위 분야라고 할 수 있습니다. 제로샷 학습은 훈련 예제 없이 보이지 않는 클래스를 분류하는 것을 목표로 합니다.

원샷 학습에서는 수업당 샘플이 하나만 있습니다. Few-Shot은 클래스당 2~5개의 샘플을 가지고 있습니다. 이는 Few-Shot이 One-Shot 학습의 보다 유연한 버전임을 의미합니다.

소형 샘플 학습 방법

Few Shot Learning 문제를 해결할 때는 일반적으로 두 가지 방법을 고려해야 합니다.

Data Level Approach (DLA)

이 전략은 매우 간단합니다. 단, 솔리드 모델을 생성할 데이터가 충분하지 않은 경우 언더슈팅 피팅과 과적합을 방지하려면 더 많은 데이터를 추가해야 합니다. 이 때문에 더 큰 기본 데이터 세트에서 더 많은 데이터를 활용하면 많은 FSL 문제를 해결할 수 있습니다. 기본 데이터세트의 주목할만한 특징은 Few-Shot 챌린지에 대한 지원 세트를 구성하는 클래스가 부족하다는 것입니다. 예를 들어, 특정 종의 새를 분류하려는 경우 기본 데이터 세트에는 다른 많은 새의 사진이 포함될 수 있습니다.

PLA(매개변수 수준 접근 방식)

매개변수 수준 관점에서 볼 때 Few-Shot 학습 샘플은 일반적으로 큰 고차원 공간을 갖기 때문에 과적합되기가 상대적으로 쉽습니다. 매개변수 공간을 제한하고 정규화를 사용하며 적절한 손실 함수를 사용하면 이 문제를 해결하는 데 도움이 됩니다. 일반화를 위해 모델에서는 소수의 훈련 샘플이 사용됩니다.

모델을 광범위한 매개변수 공간으로 안내하여 성능을 향상시킬 수 있습니다. 일반적인 최적화 방법은 훈련 데이터 부족으로 인해 정확한 결과를 생성하지 못할 수 있습니다.

위의 이유로 매개변수 공간을 통해 최적의 경로를 찾도록 모델을 교육하면 최상의 예측 결과를 얻을 수 있습니다. 이러한 접근 방식을 메타 학습이라고 합니다.

소표본 학습 이미지 분류 알고리즘

일반적인 소표본 학습 방법에는 4가지가 있습니다.

모델 독립적 메타 학습 모델 독립적 메타 학습

Gradient-based Meta-Learning(GBML) 원리는 MAML Base입니다. GBML에서 메타 학습자는 기본 모델을 훈련하고 모든 작업 표현에 걸쳐 공유 기능을 학습함으로써 사전 경험을 얻습니다. 학습할 새로운 작업이 있을 때마다 메타 학습자는 기존 경험과 새 작업에서 제공되는 최소한의 새로운 훈련 데이터를 사용하여 미세 조정됩니다.

일반적으로 매개변수를 무작위로 초기화하고 여러 번 업데이트하면 알고리즘이 좋은 성능으로 수렴되지 않습니다. MAML은 이 문제를 해결하려고 시도합니다. MAML은 몇 가지 경사 단계와 과적합 없이 메타 매개변수 학습기의 안정적인 초기화를 제공하므로 새로운 작업을 최적으로 신속하게 학습할 수 있습니다.

단계는 다음과 같습니다.

  1. 메타 학습자는 각 에피소드가 시작될 때 자체 복사본 C를 생성하고,
  2. C는 이 에피소드에 대해 훈련되고(기본 모델의 도움으로),
  3. C는 예측을 쌍으로 만듭니다.
  4. 이 예측에서 계산된 손실은 C를 업데이트하는 데 사용됩니다.
  5. 이는 모든 에피소드에 대한 훈련이 완료될 때까지 계속됩니다.

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

이 기술의 가장 큰 장점은 메타 학습 알고리즘 선택과 독립적으로 간주된다는 것입니다. 따라서 MAML 방법은 빠른 적응이 필요한 많은 기계 학습 알고리즘, 특히 심층 신경망에서 널리 사용됩니다.

Matching Networks

FSL 문제를 해결하기 위해 만들어진 첫 번째 메트릭 학습 방법은 MN(Matching Network)입니다.

Few-Shot Learning 문제를 해결하기 위해 매칭 네트워크 방법을 사용하려면 대규모 기본 데이터 세트가 필요합니다. .

이 데이터 세트를 여러 에피소드로 분할한 후 각 에피소드에 대해 매칭 네트워크는 다음을 수행합니다.

  • 지원 세트와 쿼리 세트의 각 이미지는 해당 기능을 출력하는 CNN에 공급됩니다.
  • 쿼리 임베딩 소프트맥스로 분류된 포함된 특징의 코사인 거리를 얻기 위해 지원 세트에서 훈련된 모델을 사용하여 이미지를 얻습니다.
  • 분류 결과의 교차 엔트로피 손실은 CNN을 통해 역전파되어 특징 임베딩 모델을 업데이트합니다.

일치 네트워크는 다음을 수행할 수 있습니다. 이미지 임베딩을 구축하는 방법을 알아보세요. MN은 카테고리에 대한 특별한 사전 지식 없이도 이 방법을 사용하여 사진을 분류할 수 있습니다. 단순히 클래스의 여러 인스턴스를 비교합니다.

카테고리는 에피소드마다 다르기 때문에 매칭 네트워크는 카테고리 구별에 중요한 이미지 속성(특징)을 계산합니다. 표준 분류를 사용할 때 알고리즘은 각 범주에 고유한 기능을 선택합니다.

프로토타입 네트워크

매칭 네트워크와 유사한 것이 프로토타입 네트워크(PN)입니다. 몇 가지 미묘한 변경을 통해 알고리즘의 성능을 향상시킵니다. PN은 MN보다 더 나은 결과를 얻지만 훈련 프로세스는 기본적으로 지원 세트의 일부 쿼리 이미지 임베딩을 비교하는 것과 동일하지만 프로토타입 네트워크는 다른 전략을 제공합니다.

우리는 PN에서 클래스의 프로토타입을 생성해야 합니다: 클래스의 이미지 임베딩을 평균하여 생성된 클래스 임베딩입니다. 그런 다음 이러한 클래스 프로토타입만 쿼리 이미지 임베딩을 비교하는 데 사용됩니다. 단일 표본 학습 문제에 사용되는 경우 매칭 네트워크와 비슷합니다.

관계 네트워크

관계 네트워크는 위에서 언급한 모든 방법에 대한 연구 결과를 계승한다고 할 수 있습니다. RN은 PN 아이디어를 기반으로 하지만 상당한 알고리즘 개선이 포함되어 있습니다.

이 방법에서 사용하는 거리 함수는 이전 연구처럼 미리 정의하는 것이 아니라 학습이 가능합니다. 관계 모듈은 입력 이미지에서 임베딩 및 클래스 프로토타입을 계산하는 부분인 임베딩 모듈 위에 위치합니다.

훈련 가능한 관계 모듈(거리 함수) 입력은 쿼리 이미지에 각 클래스의 프로토타입을 삽입하고 출력은 각 클래스 일치의 관계 점수입니다. 관계 점수는 Softmax를 통해 전달되어 예측을 얻습니다.

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

Open-AI Clip을 이용한 제로샷 학습

CLIP(Contrastive Language-Image Pre-Training)은 다양한 (이미지, 텍스트) 쌍을 학습한 신경망입니다. 작업에 직접 최적화하지 않고도 주어진 이미지에 대해 가장 관련성이 높은 텍스트 조각을 예측할 수 있습니다(GPT-2 및 3의 제로샷 기능과 유사).

CLIP은 ImageNet "제로 샘플"에서 원본 ResNet50의 성능에 도달할 수 있으며 레이블이 지정된 예제를 사용할 필요가 없습니다. 아래에서는 Pytorch를 사용하여 간단한 분류 모델을 구현합니다.

패키지 소개

! pip install ftfy regex tqdm
 ! pip install git+https://github.com/openai/CLIP.gitimport numpy as np
 import torch
 from pkg_resources import packaging
 
 print("Torch version:", torch.__version__)

모델 로드

import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32")
 model.cuda().eval()
 input_resolution = model.visual.input_resolution
 context_length = model.context_length
 vocab_size = model.vocab_size
 
 print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
 print("Input resolution:", input_resolution)
 print("Context length:", context_length)
 print("Vocab size:", vocab_size)

이미지 전처리

8개의 예시 이미지와 텍스트 설명을 모델에 입력하고 해당 기능 간의 유사점을 비교해 보겠습니다.

토크나이저는 대소문자를 구분하지 않으며 적절한 텍스트 설명을 자유롭게 제공할 수 있습니다.

 import os
 import skimage
 import IPython.display
 import matplotlib.pyplot as plt
 from PIL import Image
 import numpy as np
 
 from collections import OrderedDict
 import torch
 
 %matplotlib inline
 %config InlineBackend.figure_format = 'retina'
 
 # images in skimage to use and their textual descriptions
 descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
 }original_images = []
 images = []
 texts = []
 plt.figure(figsize=(16, 5))
 
 for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
 
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
 
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
 
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
 
 plt.tight_layout()

결과의 시각화는 다음과 같습니다.

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

이미지를 정규화하고, 각 텍스트 입력에 레이블을 지정하고, 모델의 순방향 전파를 실행하여 이미지와 텍스트의 특징을 얻습니다.

image_input = torch.tensor(np.stack(images)).cuda()
 text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
 
 with torch.no_grad():

특징을 정규화하고, 각 쌍의 내적을 계산하고, 코사인 유사도 계산을 수행합니다.

 image_features /= image_features.norm(dim=-1, keepdim=True)
 text_features /= text_features.norm(dim=-1, keepdim=True)
 similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
 
 count = len(descriptions)
 
 plt.figure(figsize=(20, 14))
 plt.imshow(similarity, vmin=0.1, vmax=0.3)
 # plt.colorbar()
 plt.yticks(range(count), texts, fontsize=18)
 plt.xticks([])
 for i, image in enumerate(original_images):
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
 for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
 
 for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
 
 plt.xlim([-0.5, count - 0.5])
 plt.ylim([count + 0.5, -2])
 
 plt.title("Cosine similarity between text and image features", size=20)

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

제로샘플 이미지 분류

 from torchvision.datasets import CIFAR100
 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
 text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
 text_tokens = clip.tokenize(text_descriptions).cuda()
 with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
 
 text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
 top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
 plt.figure(figsize=(16, 16))
 for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
 
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
 
 plt.subplots_adjust(wspace=0.5)
 plt.show()

PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류

분류 효과가 여전히 매우 좋은 것을 확인할 수 있습니다.

위 내용은 PyTorch를 사용한 퓨샷 학습을 통한 이미지 분류의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제