
Image classification with few-shot learning using PyTorch

Forwarded by WBOY, 2023-04-09 10:51:05

In recent years, deep learning models have performed well on tasks such as object detection and image recognition. On challenging image classification datasets such as ImageNet, which contains 1,000 object classes, some models now exceed human-level accuracy. However, these models rely on supervised training, so they depend heavily on the availability of labeled training data, and they can only recognize the classes they were trained on.

Because there are rarely enough labeled images for every class, such models may be of limited use in real-world settings. We would also like a model to recognize classes it has not seen during training, since it is practically impossible to train on images of every potential object. The problem of learning from only a few labeled examples is called few-shot learning.

What is few-shot learning?


Few-shot learning is a subfield of machine learning concerned with classifying new data when only a few labeled training samples are available. The goal is a model that still performs reasonably well with such a small amount of training data.

Consider the following scenario: in medicine, there may not be enough X-ray images of some rare diseases to train a conventional classifier. In such cases, a few-shot learning classifier is a good fit.

Variants of Few-Shot Learning

Generally speaking, researchers distinguish four types:

  1. N-Shot Learning (NSL)
  2. Few-Shot Learning (FSL)
  3. One-Shot Learning (OSL)
  4. Zero-Shot Learning (ZSL)

When we talk about FSL, we usually mean N-way-K-shot classification, where N is the number of classes and K is the number of labeled training samples per class. N-Shot Learning is considered a broader concept than the others: Few-Shot, One-Shot and Zero-Shot learning can all be regarded as subfields of NSL. Zero-shot learning, in particular, aims to classify unseen classes without any training examples.

In One-Shot Learning there is only one sample per class, while Few-Shot Learning usually has 2 to 5 samples per class, which makes it a more flexible version of One-Shot Learning.
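To make the N-way-K-shot setup concrete, here is a minimal sketch of how a single training episode could be sampled from a labeled dataset. The function name `sample_episode` and the `n_query` parameter (query images per class) are our own illustrative choices, not part of any library.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, n_query, seed=None):
    """Sample one N-way K-shot episode from a list of integer class labels.

    Returns (support, query): lists of (dataset_index, episode_label) pairs,
    where episode_label is in range(n_way).
    """
    rng = random.Random(seed)
    # Group dataset indices by class.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with enough samples for both support and query qualify.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + n_query)
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


# Usage: a toy dataset with 10 classes and 20 samples each.
labels = [i // 20 for i in range(200)]
support, query = sample_episode(labels, n_way=5, k_shot=1, n_query=3, seed=0)
print(len(support), len(query))  # 5 15
```

Episode labels are re-numbered 0..N-1 in every episode, which is what forces the model to learn how to compare samples rather than memorize fixed classes.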

Few-Shot Learning Approaches

Generally, two approaches are considered when solving a few-shot learning problem:

Data Level Approach (DLA)

This strategy is very simple: if there is not enough data to build a solid model that avoids both underfitting and overfitting, add more data. For this reason, many FSL problems can be solved by leveraging extra data from a larger base dataset. A key property of the base dataset is that it does not contain the classes that make up our support set for the few-shot task. For example, if we want to classify a particular species of bird, the base dataset may contain pictures of many other bird species.
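A minimal sketch, assuming classes are identified by string names, of how the label space could be split so that the base dataset and the few-shot classes never overlap. `split_classes` is a hypothetical helper for illustration, and the bird names are placeholders.

```python
import random


def split_classes(all_classes, n_novel, seed=0):
    """Hypothetical helper: split class names into disjoint base and novel sets."""
    rng = random.Random(seed)
    ordered = sorted(all_classes)
    novel = set(rng.sample(ordered, n_novel))
    base = [c for c in ordered if c not in novel]
    return base, sorted(novel)


bird_classes = ["crow", "eagle", "heron", "owl", "parrot", "penguin", "robin", "sparrow"]
base, novel = split_classes(bird_classes, n_novel=2)
# Train (or pretrain) only on images whose label is in `base`; images of the
# `novel` classes appear only in the support/query sets of few-shot episodes.
print(len(base), len(novel))  # 6 2
```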

Parameter-level approach (PLA)

From a parameter-level perspective, few-shot models overfit easily, because the handful of training samples must be fit in a large, high-dimensional parameter space. Restricting the parameter space, using regularization, and choosing an appropriate loss function all help with this problem, so that the model can generalize from only a few training samples.

Performance can be improved by guiding the model's search through this large parameter space. Standard optimization methods may not produce accurate results because of the lack of training data.
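One way to combine these ideas, sketched below under our own assumptions: freeze a (hypothetically pretrained) backbone so that only a small classification head is trained, and add L2 regularization through the optimizer's `weight_decay`. The tiny `nn.Linear` backbone and the random data are placeholders, not part of the original article.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder for a pretrained feature extractor (hypothetical; any backbone works).
backbone = nn.Sequential(nn.Linear(64, 32), nn.ReLU())
head = nn.Linear(32, 5)  # classifier for a 5-way task

# Restrict the parameter space: freeze the backbone so only the head is trained.
for p in backbone.parameters():
    p.requires_grad_(False)

# weight_decay applies L2 regularization to the trainable (head) parameters.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, weight_decay=1e-3)
loss_fn = nn.CrossEntropyLoss()

# A 5-way 5-shot support set (random data for illustration only).
x = torch.randn(25, 64)
y = torch.arange(5).repeat_interleave(5)

for step in range(50):
    optimizer.zero_grad()
    loss = loss_fn(head(backbone(x)), y)
    loss.backward()
    optimizer.step()

trainable = sum(p.numel() for p in [*backbone.parameters(), *head.parameters()] if p.requires_grad)
print("trainable parameters:", trainable)  # trainable parameters: 165
```

Only 165 of the parameters are fitted to the 25 support images, which is the point: fewer free parameters leave less room to overfit.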

For these reasons, we train the model to learn how to find the best path through the parameter space, which yields better predictions. This approach is called meta-learning.

Few-Shot Image Classification Algorithms

There are four common few-shot learning methods:

Model-Agnostic Meta-Learning (MAML)

MAML is based on the principle of gradient-based meta-learning (GBML). In GBML, the meta-learner gains prior experience by training a base model on many tasks and learning features shared across all of them. Whenever a new task arrives, the meta-learner is fine-tuned using this existing experience and the small amount of training data that the new task provides.

Generally, if we initialize parameters randomly and update them only a few times, the algorithm will not converge to good performance. MAML tries to solve this problem: it learns a reliable initialization of the meta-parameters from which only a few gradient steps are needed to learn a new task quickly and well, while reducing the risk of overfitting.

The steps are as follows:

  1. At the beginning of each episode, the meta-learner creates a copy C of itself.
  2. C is trained on the support set of this episode (with the help of the base model).
  3. C makes predictions on the query set.
  4. The loss computed from these predictions is used to update the meta-learner.
  5. This continues until all training episodes are completed.
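The episode loop described above can be sketched in PyTorch as follows. This is a minimal second-order MAML sketch under our own assumptions: a linear classifier stands in for the base model, `make_task` generates hypothetical synthetic tasks, and all hyperparameters are arbitrary. In MAML the query loss is backpropagated through the adaptation step into the meta-learner's own parameters, which is why `torch.autograd.grad` is called with `create_graph=True`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_way, dim = 3, 8
inner_lr, inner_steps = 0.1, 1

# Meta-parameters of a tiny linear classifier standing in for the base model.
w = (0.01 * torch.randn(dim, n_way)).requires_grad_()
b = torch.zeros(n_way, requires_grad=True)
meta_opt = torch.optim.Adam([w, b], lr=0.01)
w_init = w.detach().clone()


def forward(x, params):
    weight, bias = params
    return x @ weight + bias


def make_task(k=5):
    """Hypothetical synthetic task: each class is a Gaussian blob around a random center."""
    centers = torch.randn(n_way, dim)

    def sample():
        y = torch.arange(n_way).repeat_interleave(k)
        return centers[y] + 0.5 * torch.randn(len(y), dim), y

    return sample(), sample()  # (support set, query set)


for episode in range(200):
    (xs, ys), (xq, yq) = make_task()
    # Steps 1-2: start a copy from the meta-parameters and adapt it on the support set.
    fast = [w, b]
    for _ in range(inner_steps):
        support_loss = F.cross_entropy(forward(xs, fast), ys)
        # create_graph=True keeps the adaptation differentiable (second-order MAML).
        grads = torch.autograd.grad(support_loss, fast, create_graph=True)
        fast = [p - inner_lr * g for p, g in zip(fast, grads)]
    # Step 3: the adapted copy predicts on the query set.
    query_loss = F.cross_entropy(forward(xq, fast), yq)
    # Step 4: backpropagate the query loss through the adaptation into w and b.
    meta_opt.zero_grad()
    query_loss.backward()
    meta_opt.step()
```

Setting `create_graph=False` would drop the second-order terms and give the cheaper first-order approximation of MAML.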


The biggest advantage of this technique is that it is model-agnostic: it does not depend on a particular model architecture and works with any model trained by gradient descent. MAML is therefore widely used in machine learning settings that require rapid adaptation, especially with deep neural networks.

Matching Networks

The first metric learning method created to solve the FSL problem was the Matching Network (MN).

Using matching networks to solve a few-shot learning problem requires a large base dataset.

After dividing the dataset into several episodes, for each episode, the matching network does the following:

  • Every image from the support set and the query set is fed to a CNN that outputs feature embeddings for it.
  • Each query image is classified by computing the cosine similarity between its embedding and the support-set embeddings and applying a softmax.
  • The cross-entropy loss of the classification results is backpropagated through the CNN to update the feature-embedding model.
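The per-episode classification step can be sketched as follows, assuming the CNN embeddings have already been computed. This is a simplified version of the matching-network attention (softmax over cosine similarities, then a weighted vote over the support labels); the function name and the toy embeddings are our own.

```python
import torch
import torch.nn.functional as F


def matching_net_predict(support_emb, support_labels, query_emb, n_way):
    """Matching-network style classification by attention over the support set.

    support_emb: (S, D), support_labels: (S,) ints in [0, n_way), query_emb: (Q, D).
    Returns (Q, n_way) class probabilities.
    """
    # Cosine similarity between every query and every support embedding: (Q, S).
    sims = F.normalize(query_emb, dim=-1) @ F.normalize(support_emb, dim=-1).T
    # Softmax over the support set turns similarities into attention weights.
    attn = sims.softmax(dim=-1)
    # Each query's prediction is the attention-weighted sum of one-hot support labels.
    one_hot = F.one_hot(support_labels, n_way).float()
    return attn @ one_hot


# Toy embeddings: six orthogonal unit vectors, two per class;
# each query is a copy of one support embedding.
support_emb = torch.eye(6, 16)
support_labels = torch.tensor([0, 0, 1, 1, 2, 2])
query_emb = support_emb[[0, 3, 5]]
probs = matching_net_predict(support_emb, support_labels, query_emb, n_way=3)
print(probs.argmax(dim=-1))  # tensor([0, 1, 2])
```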

In this way, the matching network learns to build image embeddings. With this method, MN can classify images without any special prior knowledge of the categories; it simply compares them with several instances of each class.

Because the categories change from episode to episode, the matching network learns to compute image attributes (features) that are important for distinguishing categories in general, whereas in standard classification the algorithm selects features specific to each of a fixed set of categories.

Prototypical Networks

The prototypical network (PN) is similar to the matching network and improves the algorithm's performance through a few subtle changes. PN achieves better results than MN; the training process is essentially the same, comparing query image embeddings with embeddings from the support set, but the prototypical network uses a different comparison strategy.

In PN, we create a prototype for each class: an embedding obtained by averaging the embeddings of that class's support images. Query image embeddings are then compared only with these class prototypes. In the one-shot case, where each prototype is just a single embedding, PN is comparable to the matching network.
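A minimal sketch of the prototype comparison, assuming the embeddings are already computed. It scores queries by negative squared Euclidean distance to each prototype, the distance used in the original Prototypical Networks paper; the helper name and the 2-D toy embeddings are our own.

```python
import torch
import torch.nn.functional as F


def prototypical_predict(support_emb, support_labels, query_emb, n_way):
    """Classify queries by distance to class prototypes (mean support embeddings)."""
    # Prototype of each class: the mean embedding of its support images, shape (n_way, D).
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(n_way)]
    )
    # Squared Euclidean distance from each query to each prototype: (Q, n_way).
    dists = torch.cdist(query_emb, prototypes) ** 2
    # Closer prototypes get higher probability.
    return F.softmax(-dists, dim=-1)


support_emb = torch.tensor([[0.0, 0.0], [0.0, 2.0],   # class 0, prototype (0, 1)
                            [4.0, 0.0], [4.0, 2.0]])  # class 1, prototype (4, 1)
support_labels = torch.tensor([0, 0, 1, 1])
query_emb = torch.tensor([[0.5, 1.0], [3.0, 1.5]])
probs = prototypical_predict(support_emb, support_labels, query_emb, n_way=2)
print(probs.argmax(dim=-1))  # tensor([0, 1])
```

During training, the cross-entropy of these probabilities against the query labels would be backpropagated into the embedding network, just as in the matching network.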

Relation Network

The Relation Network (RN) builds on the research results of all the methods above. It is based on the ideas of PN but contains significant algorithmic improvements.

The distance function used in this method is learnable, rather than being defined in advance as in previous approaches. The relation module sits on top of the embedding module, which computes the embeddings and class prototypes from the input images.

The input of the trainable relation module (the distance function) is the embedding of the query image together with the prototype of each class, and its output is a relation score for each class. The relation scores are passed through a softmax to obtain a prediction.
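A minimal sketch of a learnable relation module operating on flat embedding vectors. The two-layer MLP and its sizes are our own assumptions (the original Relation Network uses convolutional layers on feature maps); what it illustrates is that the "distance" between a query and each class prototype is itself a trainable network.

```python
import torch
import torch.nn as nn


class RelationModule(nn.Module):
    """Learnable distance: scores how well a query embedding matches a class prototype."""

    def __init__(self, emb_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, query_emb, prototypes):
        q, n = query_emb.shape[0], prototypes.shape[0]
        # Pair every query with every prototype by concatenation: (Q, N, 2D).
        pairs = torch.cat(
            [query_emb.unsqueeze(1).expand(q, n, -1),
             prototypes.unsqueeze(0).expand(q, n, -1)],
            dim=-1,
        )
        # One relation score per (query, class) pair: (Q, N).
        return self.net(pairs).squeeze(-1)


torch.manual_seed(0)
relation = RelationModule(emb_dim=16)
query_emb = torch.randn(4, 16)   # 4 query embeddings
prototypes = torch.randn(5, 16)  # prototypes for a 5-way episode
scores = relation(query_emb, prototypes)
probs = scores.softmax(dim=-1)   # relation scores -> class prediction
```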


Zero-Shot Learning with OpenAI CLIP

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a wide variety of (image, text) pairs. Given an image, it can predict the most relevant text snippet without being directly optimized for that task, similar to the zero-shot capabilities of GPT-2 and GPT-3.

CLIP matches the performance of the original ResNet-50 on ImageNet in a zero-shot setting, without using any of ImageNet's labeled training examples. This overcomes several major challenges in computer vision. Below, we use PyTorch to build a simple classification model with it.

Importing packages

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

Loading model

import clip

clip.available_models()  # lists the names of the available CLIP models

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Image preprocessing

We feed 8 sample images and their text descriptions into the model and compare the similarity between the corresponding features.

The tokenizer is case-insensitive, and we are free to provide any suitable text description.

import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    # show each image with its filename and description as the title
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))  # CLIP's resize/crop/normalize transform
    texts.append(descriptions[name])

plt.tight_layout()

The visualization of the results is as follows:

[Figure: the eight sample images, each titled with its filename and text description]

We normalize the images, tokenize each text input, and run the model's forward pass to obtain the image and text features.

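image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    # encode both modalities; .float() converts CLIP's fp16 GPU outputs to fp32
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()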

We normalize the features and take the dot product of each pair, which gives their cosine similarity.

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)

[Figure: cosine similarity between text and image features]
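To see why normalizing the features and then taking dot products yields cosine similarity, here is a quick self-contained check with random stand-in features (no CLIP model needed; the feature size is arbitrary). It compares the matrix product against PyTorch's built-in `F.cosine_similarity`, computed pair by pair via broadcasting.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_features = torch.randn(8, 512)  # stand-ins for CLIP image features
text_features = torch.randn(8, 512)   # stand-ins for CLIP text features

# Normalize to unit length, then take all pairwise dot products: (8 texts, 8 images).
img = image_features / image_features.norm(dim=-1, keepdim=True)
txt = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = txt @ img.T

# Reference: cosine similarity of every (text, image) pair via broadcasting.
reference = F.cosine_similarity(text_features.unsqueeze(1), image_features.unsqueeze(0), dim=-1)
print(torch.allclose(similarity, reference, atol=1e-5))
```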

Zero-shot image classification

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

# build one text prompt per CIFAR-100 class and encode them all
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

# image_features (already normalized) come from the eight sample images above
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    # bar chart of the top-5 predicted classes for this image
    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

[Figure: each sample image next to its top-5 CIFAR-100 class probabilities]

As the figure shows, the zero-shot classification results are quite good.
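The scoring step of the zero-shot classifier does not depend on CLIP itself. Here is a CLIP-free sketch of it with random stand-in embeddings, so it can be tried without a GPU or a model download; `zero_shot_topk` is a hypothetical helper name, and the default `scale=100.0` simply mirrors the constant used in the code above.

```python
import torch


def zero_shot_topk(image_features, text_features, k=5, scale=100.0):
    """Rank class prompts for each image by scaled cosine similarity.

    image_features: (I, D); text_features: (C, D), one row per class prompt.
    Returns (top_probs, top_labels), each of shape (I, k).
    """
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # The scale sharpens the softmax over classes.
    probs = (scale * image_features @ text_features.T).softmax(dim=-1)
    return probs.topk(k, dim=-1)


torch.manual_seed(0)
class_emb = torch.randn(10, 32)                        # stand-in prompt embeddings, 10 classes
images = class_emb[[3, 7]] + 0.1 * torch.randn(2, 32)  # two "images" close to classes 3 and 7
top_probs, top_labels = zero_shot_topk(images, class_emb, k=3)
print(top_labels[:, 0])  # tensor([3, 7])
```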


Statement:
This article is reproduced from 51cto.com. In case of infringement, please contact admin@php.cn to have it removed.