Image classification with few-shot learning using PyTorch

In recent years, deep learning-based models have performed well in tasks such as object detection and image recognition. On challenging image classification datasets such as ImageNet, which contains 1,000 object classes, some models now exceed human-level accuracy. However, these models rely on supervised training: they depend heavily on the availability of labeled training data, and they can only recognize the classes they were trained on.

Because labeled images are not available for every class during training, these models may be less useful in real-world settings. We would also like a model to recognize classes it has not seen during training, since it is practically impossible to train on images of every potential object. Learning from only a few labeled samples is called few-shot learning.

What is few-shot learning?


Few-shot learning is a subfield of machine learning. It involves classifying new data when only a few labeled training samples, and little supervision, are available. The goal is a model that still performs reasonably well with such a small number of training samples.

Consider the following scenario: in the medical field, there may not be enough X-ray images of some rare diseases to train a conventional classifier. In such scenarios, a few-shot learning classifier is a natural fit.

Variations of Few-Shot Learning

Generally speaking, researchers have identified four types:

  1. N-Shot Learning (NSL)
  2. Few-Shot Learning (FSL)
  3. One-Shot Learning (OSL)
  4. Zero-Shot Learning (ZSL)

When we talk about FSL, we usually mean N-way K-shot classification, where N is the number of classes and K is the number of labeled training samples per class. N-shot learning is therefore considered the broadest concept: few-shot, one-shot and zero-shot learning can all be regarded as subfields of NSL. Zero-shot learning, in particular, aims to classify unseen classes without any training examples at all.
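To make the N-way K-shot setting concrete, the sketch below samples one training episode (a support set plus a query set) from a labeled dataset. It is a minimal illustration in plain Python; the function `sample_episode`, its `n_query` parameter and the toy dataset are hypothetical names chosen for this example, not part of any library.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, n_query, seed=None):
    """Sample one N-way K-shot episode (hypothetical helper).

    labels: list where labels[i] is the class of sample i.
    Returns (support, query): lists of (sample_index, episode_label) pairs,
    where episode_label is in range(n_way).
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # only classes with enough samples for both support and query can be used
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query samples")

    classes = rng.sample(sorted(eligible), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        picked = rng.sample(by_class[cls], k_shot + n_query)
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


# toy dataset: 10 classes with 20 samples each
labels = [c for c in range(10) for _ in range(20)]
support, query = sample_episode(labels, n_way=5, k_shot=1, n_query=3, seed=0)
print(len(support), len(query))  # 5 15
```

The classes are relabeled 0 to N-1 inside every episode, so the model cannot memorize global class identities; setting `k_shot=1` gives a one-shot episode.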

In one-shot learning there is only one sample per class, while few-shot learning typically has 2 to 5 samples per class, which makes few-shot learning a more flexible version of one-shot learning.

Few-Shot Learning Approaches

Generally, two approaches should be considered when solving a few-shot learning problem:

Data-Level Approach (DLA)

This strategy is very simple: if there is not enough data to build a solid model and avoid underfitting and overfitting, then more data should be added. For this reason, many FSL problems can be solved by leveraging additional data from a larger base dataset. A notable feature of the base dataset is that it does not contain the classes that make up the support set of the few-shot task. For example, if we want to classify a particular bird species, the base dataset may contain pictures of many other bird species.
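The key property of the base dataset, that it shares no classes with the few-shot support set, can be enforced with a simple class-level split. This is a hypothetical sketch: `split_classes`, the class names and the 80/20 ratio are assumptions chosen for illustration.

```python
import random


def split_classes(class_names, novel_fraction=0.2, seed=0):
    """Partition class names into disjoint base and novel sets (hypothetical helper).

    Base classes are used to (pre)train the model; novel classes are held out
    and only appear later in few-shot support/query sets.
    """
    names = sorted(class_names)
    random.Random(seed).shuffle(names)
    n_novel = max(1, round(len(names) * novel_fraction))
    return set(names[n_novel:]), set(names[:n_novel])


birds = [f"bird_species_{i}" for i in range(50)]
base, novel = split_classes(birds)
assert base.isdisjoint(novel)  # no class appears in both sets
print(len(base), len(novel))  # 40 10
```

Splitting by class rather than by image is what matters: a random split of images would leak every novel class into the base dataset.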

Parameter-Level Approach (PLA)

From the parameter-level perspective, few-shot models overfit easily, because their parameter space is usually large and high-dimensional compared with the handful of available samples. Restricting the parameter space, using regularization and choosing an appropriate loss function all help to address this problem and allow the model to generalize from a small number of training samples.
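A toy numeric illustration of the overfitting argument, assuming a linear model rather than an image classifier: with 5 samples in 50 dimensions, unregularized least squares fits the training data exactly, while an L2 penalty (ridge regression) restricts the weights to a smaller norm. The data is random and the penalty `lam = 1.0` is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 5, 50          # far fewer samples than dimensions
X = rng.normal(size=(n_samples, n_features))
y = rng.normal(size=n_samples)

# minimum-norm least-squares solution: interpolates the 5 training points exactly
w_ls = np.linalg.pinv(X) @ y

# ridge regression, w = (X^T X + lam * I)^-1 X^T y, penalizes large weights
lam = 1.0
w_ridge = np.linalg.solve(X.T @ X + lam * np.eye(n_features), X.T @ y)

print("train residual (LS):   ", np.linalg.norm(X @ w_ls - y))
print("train residual (ridge):", np.linalg.norm(X @ w_ridge - y))
print("weight norm LS vs ridge:", np.linalg.norm(w_ls), np.linalg.norm(w_ridge))
```

A perfect training fit with large weights is the typical symptom of overfitting; the penalty trades a small training error for a simpler solution.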

Performance can also be improved by guiding the model through this large parameter space. Standard optimization methods may not produce accurate results because of the lack of training data.

For these reasons, we train the model to find a good path through the parameter space so that it produces good predictions. This approach is called meta-learning.

Few-Shot Image Classification Algorithms

There are four common few-shot learning methods:

Model-Agnostic Meta-Learning (MAML)

MAML is based on the principle of gradient-based meta-learning (GBML). In GBML, the meta-learner gains prior experience by training on a base model and learning features shared across all task representations. Each time there is a new task to learn, the meta-learner is fine-tuned using its existing experience and the small amount of training data provided by the new task.

Generally, if we randomly initialize the parameters and update them only a few times, the algorithm will not converge to good performance. MAML addresses this problem: it learns a reliable initialization of the meta-learner's parameters from which a new task can be learned quickly with only a few gradient steps, while limiting overfitting to the task's few examples.

The steps are as follows:

  1. The meta-learner creates a copy C of itself at the beginning of each episode.
  2. C is trained on the episode's support set (with the help of the base model).
  3. C makes predictions on the query set.
  4. The loss calculated from these predictions is used to update the meta-learner (the gradient flows back through C to the original parameters).
  5. This continues until all training episodes are completed.
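The episode loop can be sketched in NumPy under deliberately simple assumptions: the tasks are hypothetical 1-D linear-regression problems (not image classification), the copy C is produced by a single inner gradient step, and because the model is linear the second-order MAML gradient can be computed exactly from the Hessian (2/n)XᵀX of the support loss. The learning rates and task distribution are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, beta = 0.1, 0.05  # inner-loop (task) and outer-loop (meta) learning rates


def sample_task():
    """Hypothetical toy task: 1-D linear regression y = a*x + b with random a, b."""
    a, b = rng.uniform(2.0, 3.0), rng.uniform(-1.0, 1.0)

    def draw(n):
        x = rng.uniform(-1.0, 1.0, size=n)
        X = np.stack([x, np.ones(n)], axis=1)  # features [x, 1]
        return X, a * x + b

    return draw


def loss_and_grad(theta, X, y):
    """Mean squared error of the linear model X @ theta, and its gradient."""
    err = X @ theta - y
    return np.mean(err ** 2), 2.0 / len(y) * X.T @ err


def adapt(theta, X_s, y_s):
    """Inner loop: the copy C is theta after one gradient step on the support set."""
    return theta - alpha * loss_and_grad(theta, X_s, y_s)[1]


def meta_train(theta, n_steps=1000, tasks_per_batch=8, k_shot=5):
    for _ in range(n_steps):
        meta_grad = np.zeros_like(theta)
        for _ in range(tasks_per_batch):
            draw = sample_task()
            X_s, y_s = draw(k_shot)  # support set
            X_q, y_q = draw(k_shot)  # query set
            theta_c = adapt(theta, X_s, y_s)
            _, g_q = loss_and_grad(theta_c, X_q, y_q)
            # chain rule through the inner step: d theta_c / d theta = I - alpha * H_s,
            # where H_s = (2/n) X_s^T X_s is the Hessian of the support loss
            H_s = 2.0 / k_shot * X_s.T @ X_s
            meta_grad += (np.eye(len(theta)) - alpha * H_s) @ g_q
        # outer update: the query loss updates the meta-parameters theta, not the copy
        theta = theta - beta * meta_grad / tasks_per_batch
    return theta


def post_adaptation_loss(theta, n_tasks=200, k_shot=5):
    """Average query loss on new tasks after one adaptation step from theta."""
    losses = []
    for _ in range(n_tasks):
        draw = sample_task()
        X_s, y_s = draw(k_shot)
        X_q, y_q = draw(50)
        losses.append(loss_and_grad(adapt(theta, X_s, y_s), X_q, y_q)[0])
    return float(np.mean(losses))


theta0 = np.zeros(2)
theta_meta = meta_train(theta0)
print("zero init:", post_adaptation_loss(theta0))
print("MAML init:", post_adaptation_loss(theta_meta))
```

For this task family the learned initialization should settle near the average task (a ≈ 2.5, b ≈ 0), which is why one adaptation step from it beats one step from the zero initialization. Deep-learning implementations obtain the same second-order term through automatic differentiation instead of an explicit Hessian.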


The biggest advantage of this technique is that it is model-agnostic: it is independent of the choice of model and can be applied to any model trained with gradient descent. Therefore, MAML is widely used in settings that require rapid adaptation, especially with deep neural networks.

Matching Networks

The Matching Network (MN) was one of the first metric-learning methods designed to solve the FSL problem.

Using the matching network to solve a few-shot learning problem requires a large base dataset.

After dividing the dataset into episodes, the matching network does the following for each episode:

  • Every image from the support set and the query set is fed to a CNN that outputs feature embeddings for it.
  • Each query image is classified by computing the cosine similarity between its embedding and the embeddings of the support images, then applying a softmax.
  • The cross-entropy loss on the classification results is backpropagated through the CNN to update the feature embedding model.
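The classification step can be sketched in NumPy as softmax attention over the support set, in the form described in the Matching Networks paper but without its full-context embeddings. The embeddings are assumed to come from a CNN; here they are replaced by hand-made orthogonal vectors, and `matching_predict` is a hypothetical helper name.

```python
import numpy as np


def matching_predict(support_emb, support_labels, query_emb, n_classes):
    """Label distribution for each query via softmax attention over the support set.

    support_emb: (S, D), support_labels: (S,) ints in [0, n_classes), query_emb: (Q, D)
    """
    s = support_emb / np.linalg.norm(support_emb, axis=1, keepdims=True)
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    cos = q @ s.T                                  # (Q, S) cosine similarities
    attn = np.exp(cos - cos.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)        # softmax over the support samples
    one_hot = np.eye(n_classes)[support_labels]    # (S, C)
    return attn @ one_hot                          # (Q, C): attention-weighted labels


# toy 3-way 2-shot episode with orthogonal "embeddings" (stand-ins for CNN outputs)
support = np.eye(6)
labels = np.array([0, 0, 1, 1, 2, 2])
query = support[[1, 4]]                            # copies of a class-0 and a class-2 image
probs = matching_predict(support, labels, query, n_classes=3)
print(probs.argmax(axis=1))                        # [0 2]
```

During training, the cross-entropy of these probabilities against the query labels would be backpropagated into the CNN; that part is omitted here.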

In this way, the matching network learns to build image embeddings. With this method, MN can classify images without any special prior knowledge of the categories; it simply compares the query with a few instances of each class.

Because the categories change from episode to episode, the matching network learns to compute image attributes (features) that are useful for distinguishing between categories in general. In standard classification, by contrast, the algorithm selects features that are specific to each of its fixed categories.

Prototypical Networks

The prototypical network (PN) is similar to the matching network but improves performance through a few subtle changes. PN achieves better results than MN, although their training processes are essentially the same: both compare query image embeddings with embeddings from the support set. The prototypical network, however, uses a different comparison strategy.

In PN we create a prototype for each class: an embedding of the class obtained by averaging the embeddings of the class's images. Query image embeddings are then compared only with these class prototypes. In one-shot learning problems, each prototype is simply the embedding of the single support image, so PN performs comparably to the matching network.
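A minimal NumPy sketch of the prototype computation, assuming the embeddings have already been computed by the embedding network. Following the Prototypical Networks paper, queries are scored with the squared Euclidean distance to each prototype; the helper names `prototypes` and `proto_predict` are hypothetical.

```python
import numpy as np


def prototypes(support_emb, support_labels, n_classes):
    """Class prototype = mean embedding of that class's support samples."""
    return np.stack([support_emb[support_labels == c].mean(axis=0)
                     for c in range(n_classes)])


def proto_predict(protos, query_emb):
    """Softmax over negative squared Euclidean distances to each prototype."""
    d2 = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)  # (Q, C)
    logits = -d2
    logits -= logits.max(axis=1, keepdims=True)      # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)


# 2-way 2-shot toy example with 2-D embeddings
support = np.array([[0.0, 0.0], [0.0, 2.0],          # class 0 -> prototype (0, 1)
                    [4.0, 0.0], [4.0, 2.0]])         # class 1 -> prototype (4, 1)
labels = np.array([0, 0, 1, 1])
protos = prototypes(support, labels, n_classes=2)
query = np.array([[0.5, 1.0], [3.0, 1.5]])
print(protos.tolist())                               # [[0.0, 1.0], [4.0, 1.0]]
print(proto_predict(protos, query).argmax(axis=1))   # [0 1]
```

Each query is compared with C prototypes instead of every support image, which keeps the cost per query independent of K.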

Relation Network

The relation network (RN) builds on the research results of all the methods mentioned above. RN is based on the ideas of PN but contains significant algorithmic improvements.

The distance function used in this method is learnable instead of being defined in advance as in the previous methods. The relation module sits on top of the embedding module, which is the part that computes embeddings and class prototypes from the input images.

The input to the trainable relation module (the distance function) is the embedding of the query image together with the prototype of each class, and its output is a relation score for each class. The relation scores are passed through a softmax to get the prediction.
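The idea of a learnable distance can be sketched as a small network applied to each (query embedding, class prototype) pair. This is a forward-pass sketch only, under several assumptions: the weights are random and untrained, the sizes `D` and `H` are arbitrary, and a two-layer MLP on concatenated vectors stands in for the convolutional relation module of the Relation Network paper (which also uses a sigmoid score trained with mean squared error; here the scores go through a softmax, following the description above).

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 8, 16  # embedding size and hidden size (assumed)

# hypothetical, untrained relation-module weights; in practice these are learned
W1 = rng.normal(scale=0.1, size=(2 * D, H))
b1 = np.zeros(H)
W2 = rng.normal(scale=0.1, size=(H, 1))
b2 = np.zeros(1)


def relation_scores(query_emb, protos):
    """Learnable 'distance': an MLP applied to every [query, prototype] pair.

    query_emb: (Q, D), protos: (C, D) -> scores: (Q, C)
    """
    Q, C = len(query_emb), len(protos)
    pairs = np.concatenate([np.repeat(query_emb[:, None, :], C, axis=1),
                            np.repeat(protos[None, :, :], Q, axis=0)], axis=-1)  # (Q, C, 2D)
    hidden = np.maximum(0.0, pairs @ W1 + b1)  # ReLU
    return (hidden @ W2 + b2)[..., 0]          # (Q, C) relation scores


def predict(query_emb, protos):
    s = relation_scores(query_emb, protos)
    e = np.exp(s - s.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)    # softmax over classes


protos = rng.normal(size=(5, D))   # e.g. 5 class prototypes from the embedding module
queries = rng.normal(size=(3, D))
print(predict(queries, protos).shape)  # (3, 5)
```

Because W1, b1, W2 and b2 would be trained jointly with the embedding module, the network can learn a comparison that a fixed cosine or Euclidean distance cannot express.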


Using OpenAI CLIP for zero-shot learning

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a wide variety of (image, text) pairs. Given an image, it can predict the most relevant text snippet without being directly optimized for that task, similar to the zero-shot capabilities of GPT-2 and GPT-3.
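The contrastive idea behind CLIP's pre-training can be sketched in NumPy: for a batch of N matching (image, text) embedding pairs, a symmetric cross-entropy loss pushes each image's similarity with its own text above its similarity with the other N-1 texts, and vice versa. This follows the spirit of the pseudocode in the CLIP paper, but the random embeddings and the fixed temperature of 0.07 are assumptions for illustration (CLIP learns its temperature during training).

```python
import numpy as np


def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def clip_style_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric contrastive loss: row i of img_emb is paired with row i of txt_emb."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature             # (N, N) scaled cosine similarities
    idx = np.arange(len(logits))                   # matching pairs lie on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # each image picks its text
    loss_t2i = -log_softmax(logits, axis=0)[idx, idx].mean()  # each text picks its image
    return (loss_i2t + loss_t2i) / 2


rng = np.random.default_rng(0)
img = rng.normal(size=(4, 32))
txt_matched = img + 0.05 * rng.normal(size=(4, 32))  # texts whose embeddings match their images
txt_shuffled = txt_matched[[1, 2, 3, 0]]             # same texts, wrong pairing
print(clip_style_loss(img, txt_matched), clip_style_loss(img, txt_shuffled))
```

The correctly paired batch should yield a much lower loss than the shuffled one. At inference time the same cosine similarities, computed against text prompts for each class, are what make zero-shot classification possible.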

CLIP matches the performance of the original ResNet50 on ImageNet "zero-shot", without using any of the original labeled training examples, which overcomes several major challenges in computer vision. Below we use PyTorch to implement a simple classification model.

Importing packages

!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

import numpy as np
import torch

print("Torch version:", torch.__version__)

Loading the model

import clip

clip.available_models()  # it will list the names of available CLIP models

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Image preprocessing

We feed 8 sample images and their text descriptions into the model and compare the similarity between the corresponding features.

The tokenizer is case-insensitive, and we are free to give any suitable text description.

import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    # show each image titled with its filename and description
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))  # resize, crop and normalize for CLIP
    texts.append(descriptions[name])

plt.tight_layout()

The visualization of the results is as follows:

[Figure: the eight sample images, each titled with its filename and text description]

The images were already normalized by `preprocess`; we stack them into a batch, tokenize each text input, and run the model's forward pass to get the image and text features.

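image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    # encode both modalities; .float() casts the half-precision CUDA outputs to fp32
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()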

We normalize the features and compute the dot product of each pair, which gives their cosine similarity.

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)

[Figure: cosine similarity between text and image features]

Zero-shot image classification

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

# one text prompt per CIFAR-100 class
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

# image_features were normalized in the previous step; scaled cosine similarity -> softmax
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    # bar chart of the top-5 predicted classes for this image
    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

[Figure: each sample image next to its top-5 CIFAR-100 class probabilities]

As can be seen, the zero-shot classification results are still very good.


This article is reproduced from 51CTO.COM.