Maison >Périphériques technologiques >IA >Dash, le cadre d'apprentissage semi-supervisé open source de Damo Academy, actualise de nombreux SOTA
Apprentissage supervisé
Nous savons que le but de la formation de modèles est en fait d'apprendre une fonction de prédiction. Mathématiquement, cela peut être décrit comme un apprentissage à partir de données (fonction de cartographie. de X) à l’annotation (y). L'apprentissage supervisé est l'une des méthodes de formation de modèles les plus couramment utilisées. L'amélioration de son effet dépend d'une grande quantité de données de formation bien étiquetées, ce qu'on appelle une grande quantité de données étiquetées ((X, y)). Cependant, l'étiquetage des données nécessite souvent beaucoup de main d'œuvre et de ressources matérielles, etc. Par conséquent, même si l'effet est amélioré, cela entraînera également le problème du coût élevé. Dans les applications pratiques, on constate souvent qu'il y a une petite quantité de données étiquetées et une grande quantité de données non étiquetées. L'apprentissage semi-supervisé qui en résulte a également attiré de plus en plus l'attention des chercheurs.
Apprentissage semi-supervisé
L'apprentissage semi-supervisé apprend simultanément une petite quantité de données étiquetées et une grande quantité de données non étiquetées, et son objectif est d'utiliser des données non étiquetées pour améliorer le modèle Précision. Par exemple, l'auto-formation est une méthode d'apprentissage semi-supervisé très courante. Son processus spécifique consiste à apprendre le mappage des données de X à y pour les données étiquetées (X, y), et en même temps à utiliser le modèle appris pour prédire. une pseudo-étiquette X de données non étiquetées aide le modèle à obtenir une meilleure convergence et à améliorer la précision en effectuant davantage d'apprentissage supervisé sur les données de pseudo-étiquette (X, ).
Résolution des problèmes de base
Le cadre d'apprentissage semi-supervisé existant peut être grossièrement divisé en deux types : l'un consiste à participer à la formation dans son intégralité et l'autre à utiliser un programme fixe. seuil Retirez les échantillons avec un niveau de confiance plus élevé pour la formation (comme FixMatch). Étant donné que l'utilisation de données non étiquetées par apprentissage semi-supervisé dépend des pseudo-étiquettes prédites par le modèle actuel, la précision des pseudo-étiquettes aura un plus grand impact sur la formation du modèle. De bons résultats de prédiction contribueront à la convergence et à la précision de. le modèle. Pour l’apprentissage d’un nouveau modèle, de mauvais résultats de prédiction interféreront avec la formation du modèle. Nous pensons donc : Tous les échantillons non étiquetés ne sont pas nécessaires !
Cet article propose de manière innovante l'utilisation du seuil dynamique (seuil dynamique) comme méthode de sélection d'échantillons non étiquetés pour l'apprentissage semi-supervisé (SSL), nous avons transformé le cadre de formation de l'apprentissage semi-supervisé et amélioré la stratégie de sélection d'échantillons non étiquetés pendant le processus de formation en modifiant dynamiquement les seuils pour sélectionner des échantillons non étiquetés plus efficaces pour la formation. Dash est une stratégie générale qui peut être facilement intégrée aux méthodes d’apprentissage semi-supervisées existantes. Expérimentalement, nous avons pleinement vérifié son efficacité sur des ensembles de données standards tels que CIFAR-10, CIFAR-100, STL-10 et SVHN. En théorie, l'article prouve les propriétés de convergence de l'algorithme Dash du point de vue de l'optimisation non convexe.
Cadre de formation Fixmatch
Avant de présenter notre méthode Dash, nous présentons l'algorithme FixMatch proposé par Google, une méthode d'apprentissage semi-supervisé qui utilise un seuil fixe pour sélectionner des échantillons non étiquetés. Le cadre de formation FixMatch était la précédente solution SOTA. Les points clés de l'ensemble du cadre d'apprentissage peuvent être résumés comme les points suivants :
1. Pour les échantillons obtenus par une faible amélioration des données (retournement horizontal, décalage, etc.) de données non étiquetées, la valeur prédite est obtenue via le courant. modèle
2 . Pour les données non étiquetées, les échantillons obtenus grâce à une forte amélioration des données (RA ou CTA) sont utilisés pour obtenir des valeurs prédites grâce au modèle actuel
3. avec une grande confiance sont formés grâce à une méthode chaude Pseudo-étiquette , puis utilisez et la valeur prédite de X obtenue grâce à une forte amélioration des données pour entraîner le modèle. L'avantage de
fixmatch est qu'il utilise des données faiblement améliorées pour prédire les pseudo-étiquettes, ce qui augmente la précision des prédictions de pseudo-étiquettes, et utilise un seuil fixe de 0,95 (correspondant à une perte de 0,0513) pendant le processus de formation pour sélectionner un niveau de confiance élevé (le seuil est supérieur à 0,95, c'est-à-dire les échantillons prédits avec une perte inférieure ou égale à 0,0513) génèrent des pseudo-étiquettes, stabilisant davantage le processus de formation.
Cadre de formation Dash
Visant le problème de la sélection de tous les pseudo-étiquettes et de l'utilisation de seuils fixes pour sélectionner les pseudo-étiquettes, nous proposons de manière innovante une stratégie d'utilisation de seuils dynamiques pour le criblage d'échantillons. Autrement dit, le seuil dynamique est atténué avec t
où C=1,0001, est la perte moyenne de données étiquetées après la première époque, et nous sélectionnons ces échantillons non étiquetés pour participer à la rétropropagation du gradient. La figure ci-dessous montre la courbe de variation du seuil sous différentes valeurs . Vous pouvez voir que le paramètre contrôle le taux de déclin de la courbe de seuil. La courbe de changement de est similaire à la tendance décroissante de la fonction de perte lors de la simulation du modèle d'entraînement.
La figure suivante compare les changements dans le nombre d'échantillons corrects et le nombre d'échantillons incorrects sélectionnés par FixMath et Dash au cours du processus de formation (l'ensemble de données utilisé est cifar100). Il ressort clairement de la figure que par rapport à FixMatch, Dash peut sélectionner plus d'échantillons avec des étiquettes correctes et moins d'échantillons avec des étiquettes incorrectes, ce qui contribue finalement à améliorer la précision du modèle de formation.
Notre algorithme peut être résumé comme suit Algorithme 1. Dash est une stratégie générale qui peut être facilement intégrée aux méthodes d’apprentissage semi-supervisées existantes. Pour plus de commodité, dans les expériences de cet article, nous intégrons principalement Dash avec FixMatch. Pour des preuves plus théoriques, voir l'article.
Nous avons vérifié l'algorithme sur des ensembles de données d'apprentissage semi-supervisé couramment utilisés : CIFAR-10, CIFAR-100, STL-10 et SVHN. Les résultats sont les suivants :
On peut voir que notre méthode a obtenu de meilleurs résultats que SOTA dans plusieurs contextes expérimentaux. Ce qui doit être expliqué, c'est l'expérience pour le label CIFAR-100 400, ReMixMatch Using. l'astuce supplémentaire d'alignement des données permet d'obtenir de meilleurs résultats. Après avoir ajouté l'astuce d'alignement des données à Dash, un taux d'erreur de 43,31 % peut être atteint, ce qui est inférieur au taux d'erreur de 44,28 % de ReMixMatch.
Dans le processus de développement actuel de modèles orientés vers des domaines de tâches, le framework Dash semi-supervisé est souvent appliqué. Ensuite, j'aimerais vous présenter les modèles gratuits open source que nous avons développés dans divers domaines. Vous êtes invités à les expérimenter et à les télécharger (vous pouvez en faire l'expérience sur la plupart des téléphones mobiles) :
.Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!