Maison >Périphériques technologiques >IA >Classification d'images avec apprentissage en quelques prises de vue à l'aide de PyTorch
Ces dernières années, les modèles basés sur l'apprentissage profond ont donné de bons résultats dans des tâches telles que la détection d'objets et la reconnaissance d'images. Sur des ensembles de données de classification d'images complexes comme ImageNet, qui contient 1 000 classifications d'objets différentes, certains modèles dépassent désormais les niveaux humains. Mais ces modèles s'appuient sur un processus de formation supervisé, ils sont considérablement affectés par la disponibilité de données de formation étiquetées, et les classes que les modèles sont capables de détecter sont limitées aux classes sur lesquelles ils ont été formés.
Comme il n'y a pas suffisamment d'images étiquetées pour toutes les classes pendant la formation, ces modèles peuvent être moins utiles dans des contextes réels. Et nous voulons que le modèle soit capable de reconnaître les classes qu'il n'a pas vues lors de l'entraînement, car il est presque impossible de s'entraîner sur des images de tous les objets potentiels. Le problème pour lequel nous apprendrons à partir de quelques échantillons est appelé apprentissage en quelques coups.
L'apprentissage en quelques étapes est un sous-domaine de l'apprentissage automatique. Cela implique de classer de nouvelles données avec seulement quelques échantillons de formation et données de supervision. Le modèle que nous avons créé fonctionne raisonnablement bien avec seulement un petit nombre d'échantillons d'apprentissage.
Considérez le scénario suivant : Dans le domaine médical, pour certaines maladies rares, il se peut qu'il n'y ait pas suffisamment d'images radiographiques pour la formation. Pour de tels scénarios, la création d’un classificateur d’apprentissage en quelques étapes est la solution parfaite.
Généralement, les chercheurs ont identifié quatre types :
Quand on parle de FSL, on fait généralement référence à la classification N-way-K-Shot. N représente le nombre de classes et K représente le nombre d'échantillons à former dans chaque classe. Ainsi, N-Shot Learning est considéré comme un concept plus large que tous les autres concepts. On peut dire que Few-Shot, One-Shot et Zero-Shot sont des sous-domaines de NSL. Alors que l'apprentissage zéro-shot vise à classer les classes invisibles sans aucun exemple de formation.
Dans One-Shot Learning, il n'y a qu'un seul échantillon par classe. Few-Shot propose 2 à 5 échantillons par classe, ce qui signifie que Few-Shot est une version plus flexible de One-Shot Learning.
Généralement, deux méthodes doivent être envisagées lors de la résolution du problème d'apprentissage par quelques coups :
Cette stratégie est très simple, s'il n'y a pas suffisamment de données pour créer un modèle solide et éviter les sous-ajustements et les surajustements, alors davantage de données doivent être ajoutées. Pour cette raison, de nombreux problèmes de FLS peuvent être résolus en exploitant davantage de données provenant d’un ensemble de données sous-jacentes plus vaste. Une caractéristique notable de l'ensemble de données de base est qu'il lui manque les classes qui constituent notre ensemble de support pour le défi Few-Shot. Par exemple, si nous souhaitons classer une certaine espèce d’oiseau, l’ensemble de données sous-jacent peut contenir des images de nombreux autres oiseaux.
Du point de vue du niveau des paramètres, les échantillons d'apprentissage par quelques tirs sont relativement faciles à surajuster car ils ont généralement de grands espaces de grande dimension. Restreindre l'espace des paramètres, utiliser la régularisation et utiliser une fonction de perte appropriée aideront à résoudre ce problème. Un petit nombre d'échantillons d'apprentissage seront utilisés par le modèle pour généraliser.
Les performances peuvent être améliorées en guidant le modèle dans un large espace de paramètres. Les méthodes d'optimisation normales peuvent ne pas produire de résultats précis en raison du manque de données d'entraînement.
Pour les raisons ci-dessus, entraîner notre modèle pour trouver le meilleur chemin à travers l'espace des paramètres produit les meilleurs résultats de prédiction. Cette approche est appelée méta-apprentissage.
Il existe 4 méthodes courantes d'apprentissage pour petits échantillons :
Méta-apprentissage basé sur les gradients (GBML) Le principe est la base MAML. En GBML, les méta-apprenants acquièrent une expérience préalable grâce à la formation du modèle de base et à l'apprentissage des fonctionnalités partagées dans toutes les représentations de tâches. Chaque fois qu'il y a une nouvelle tâche à apprendre, le méta-apprenant est affiné en utilisant son expérience existante et la quantité minimale de nouvelles données de formation fournies par la nouvelle tâche.
Généralement, si nous initialisons les paramètres de manière aléatoire et les mettons à jour plusieurs fois, l'algorithme ne convergera pas vers de bonnes performances. MAML tente de résoudre ce problème. MAML fournit une initialisation fiable de l'apprenant des métaparamètres avec seulement quelques étapes de gradient et sans surapprentissage, afin que de nouvelles tâches puissent être apprises de manière optimale et rapide.
Les étapes sont les suivantes :
Le plus gros avantage de cette technique est qu'elle est considérée comme indépendante du choix de l'algorithme de méta-apprentissage. Par conséquent, les méthodes MAML sont largement utilisées dans de nombreux algorithmes d’apprentissage automatique qui nécessitent une adaptation rapide, notamment dans les réseaux de neurones profonds.
La première méthode d'apprentissage métrique créée pour résoudre le problème FSL est le Matching Network (MN).
Un grand ensemble de données de base est requis lors de l'utilisation de la méthode de réseau de correspondance pour résoudre le problème d'apprentissage en quelques coups. .
Après avoir divisé cet ensemble de données en plusieurs épisodes, pour chaque épisode, le réseau de correspondance effectue les opérations suivantes :
Le réseau correspondant peut être utilisé de cette manière Apprenez à créer des intégrations d'images. MN est capable de classer les photos en utilisant cette méthode sans aucune connaissance préalable particulière des catégories. Il compare simplement plusieurs instances de la classe.
Étant donné que les catégories varient d'un épisode à l'autre, le réseau de correspondance calcule les attributs d'image (caractéristiques) qui sont importants pour la distinction des catégories. Lors de l'utilisation de la classification standard, l'algorithme sélectionne les caractéristiques uniques à chaque catégorie.
Le réseau prototypique (PN) est similaire au réseau correspondant. Il améliore les performances de l'algorithme grâce à quelques changements subtils. PN obtient de meilleurs résultats que MN, mais leur processus de formation est essentiellement le même, il suffit de comparer certaines intégrations d'images de requête de l'ensemble de support, mais le réseau prototype propose des stratégies différentes.
Nous devons créer un prototype de la classe en PN : l'intégration de la classe créée en faisant la moyenne des intégrations des images dans la classe. Seuls ces prototypes de classe sont ensuite utilisés pour comparer les intégrations d’images de requête. Lorsqu'il est utilisé pour des problèmes d'apprentissage à échantillon unique, il est comparable aux réseaux d'appariement.
Le réseau relationnel peut être considéré comme héritant des résultats de la recherche sur toutes les méthodes mentionnées ci-dessus. RN est basé sur les idées de PN mais contient des améliorations significatives de l'algorithme.
La fonction de distance utilisée par cette méthode est apprenable, plutôt que de la définir à l'avance comme les études précédentes. Le module de relation se trouve au-dessus du module d'intégration, qui est la partie qui calcule les intégrations et les prototypes de classe à partir de l'image d'entrée.
L'entrée du module de relation entraînable (fonction de distance) est l'intégration de l'image de requête avec le prototype de chaque classe, et la sortie est le score de relation de chaque correspondance de classe. Le score de relation est transmis via Softmax pour obtenir une prédiction.
CLIP (Contrastive Language-Image Pre-Training) est un réseau de neurones entraîné sur diverses paires (image, texte). Il peut prédire les fragments de texte les plus pertinents pour une image donnée sans être directement optimisé pour la tâche (similaire à la fonctionnalité zéro tir de GPT-2 et 3).
CLIP peut atteindre les performances du ResNet50 original sur ImageNet "zéro échantillon" et ne nécessite l'utilisation d'aucun exemple étiqueté. Il surmonte plusieurs défis majeurs en vision par ordinateur. Ci-dessous, nous utilisons Pytorch pour implémenter un modèle de classification simple.
Présentation du package
! pip install ftfy regex tqdm ! pip install git+https://github.com/openai/CLIP.gitimport numpy as np import torch from pkg_resources import packaging print("Torch version:", torch.__version__)
Chargement du modèle
import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32") model.cuda().eval() input_resolution = model.visual.input_resolution context_length = model.context_length vocab_size = model.vocab_size print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") print("Input resolution:", input_resolution) print("Context length:", context_length) print("Vocab size:", vocab_size)
Prétraitement des images
Nous saisirons 8 exemples d'images et leurs descriptions textuelles dans le modèle et comparerons les similitudes entre les fonctionnalités correspondantes.
Le tokenizer n'est pas sensible à la casse et nous sommes libres de donner toute description textuelle appropriée.
import os import skimage import IPython.display import matplotlib.pyplot as plt from PIL import Image import numpy as np from collections import OrderedDict import torch %matplotlib inline %config InlineBackend.figure_format = 'retina' # images in skimage to use and their textual descriptions descriptions = { "page": "a page of text about segmentation", "chelsea": "a facial photo of a tabby cat", "astronaut": "a portrait of an astronaut with the American flag", "rocket": "a rocket standing on a launchpad", "motorcycle_right": "a red motorcycle standing in a garage", "camera": "a person looking at a camera on a tripod", "horse": "a black-and-white silhouette of a horse", "coffee": "a cup of coffee on a saucer" }original_images = [] images = [] texts = [] plt.figure(figsize=(16, 5)) for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]: name = os.path.splitext(filename)[0] if name not in descriptions: continue image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB") plt.subplot(2, 4, len(images) + 1) plt.imshow(image) plt.title(f"{filename}n{descriptions[name]}") plt.xticks([]) plt.yticks([]) original_images.append(image) images.append(preprocess(image)) texts.append(descriptions[name]) plt.tight_layout()
La visualisation des résultats est la suivante :
Nous normalisons les images, étiquetons chaque entrée de texte et exécutons la propagation avant du modèle pour obtenir les caractéristiques des images et du texte.
image_input = torch.tensor(np.stack(images)).cuda() text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda() with torch.no_grad():
Nous normalisons les caractéristiques, calculons le produit scalaire de chaque paire et effectuons un calcul de similarité cosinus
image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T count = len(descriptions) plt.figure(figsize=(20, 14)) plt.imshow(similarity, vmin=0.1, vmax=0.3) # plt.colorbar() plt.yticks(range(count), texts, fontsize=18) plt.xticks([]) for i, image in enumerate(original_images): plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower") for x in range(similarity.shape[1]): for y in range(similarity.shape[0]): plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12) for side in ["left", "top", "right", "bottom"]: plt.gca().spines[side].set_visible(False) plt.xlim([-0.5, count - 0.5]) plt.ylim([count + 0.5, -2]) plt.title("Cosine similarity between text and image features", size=20)
Classification d'image à échantillon zéro
from torchvision.datasets import CIFAR100 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True) text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes] text_tokens = clip.tokenize(text_descriptions).cuda() with torch.no_grad(): text_features = model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) top_probs, top_labels = text_probs.cpu().topk(5, dim=-1) plt.figure(figsize=(16, 16)) for i, image in enumerate(original_images): plt.subplot(4, 4, 2 * i + 1) plt.imshow(image) plt.axis("off") plt.subplot(4, 4, 2 * i + 2) y = np.arange(top_probs.shape[-1]) plt.grid() plt.barh(y, top_probs[i]) plt.gca().invert_yaxis() plt.gca().set_axisbelow(True) plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()]) plt.xlabel("probability") plt.subplots_adjust(wspace=0.5) plt.show()
Vous pouvez voir que l'effet de classification est toujours très bon OK
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!