Maison > Article > développement back-end > Optimiser vos réseaux de neurones
La semaine dernière, j'ai publié un article sur la façon de construire des réseaux de neurones simples, en particulier des perceptrons multicouches. Cet article approfondira les spécificités des réseaux de neurones pour expliquer comment nous pouvons maximiser les performances d'un réseau de neurones en modifiant ses configurations.
Lors de la formation d'un modèle, vous pourriez penser que si vous entraînez suffisamment votre modèle, celui-ci deviendra parfait. Cela est peut-être vrai, mais cela ne vaut que pour l'ensemble de données sur lequel il a été formé. En fait, si vous lui fournissez un autre ensemble de données dont les valeurs sont différentes, le modèle pourrait produire des prédictions complètement incorrectes.
Pour mieux comprendre cela, disons que vous vous entraînez chaque jour pour votre examen de conduite en conduisant en ligne droite sans bouger le volant. (S'il vous plaît, ne faites pas cela.) Bien que vous obteniez probablement de très bons résultats sur la piste d'accélération, si on vous demandait de tourner à gauche lors de l'examen lui-même, vous pourriez finir par vous transformer en panneau STOP.
Ce phénomène est appelé surapprentissage. Votre modèle peut apprendre tous les aspects et modèles des données sur lesquelles il est formé, mais s'il apprend un modèle qui adhère trop étroitement à l'ensemble de données d'entraînement, alors lorsqu'il reçoit un nouvel ensemble de données, votre modèle fonctionnera mal. Dans le même temps, si vous n’entraînez pas suffisamment votre modèle, celui-ci ne sera pas en mesure de reconnaître correctement les modèles d’autres ensembles de données. Dans ce cas, vous seriez sous-ajusté.
Dans l'exemple ci-dessus, une excellente position pour arrêter l'entraînement de votre modèle serait la bonne lorsque la perte de validation atteint son minimum. Il est possible de le faire avec arrêt précoce, qui arrête l'entraînement une fois qu'il n'y a pas d'amélioration de la perte de validation après un nombre arbitraire de cycles d'entraînement (époques).
L'entraînement de votre modèle consiste à trouver un équilibre entre le surajustement et le sous-apprentissage tout en utilisant un arrêt précoce si nécessaire. C'est pourquoi votre ensemble de données d'entraînement doit être aussi représentatif que possible de votre population globale afin que votre modèle puisse faire des prédictions plus précises sur les données qu'il n'a pas vues.
L'une des configurations d'entraînement les plus importantes pouvant être modifiées est peut-être la fonction de perte, qui est "l'inexactitude" entre les prédictions de votre modèle et leurs valeurs réelles. L'« imprécision » peut être représentée mathématiquement de nombreuses manières différentes, l'une des plus courantes étant l'erreur quadratique moyenne (MSE) :
où yiˉ est la prédiction du modèle et yi est la vraie valeur. Il existe une variante similaire appelée erreur absolue moyenne (MAE)
Quelle est la différence entre ces deux-là et lequel est le meilleur ? La vraie réponse est que cela dépend de divers facteurs. Considérons un exemple simple de régression linéaire bidimensionnelle.
Dans de nombreux cas, il peut y avoir des points de données qui constituent des valeurs aberrantes, des points éloignés des autres points de données. En termes de régression linéaire, cela signifie qu'il y a quelques points sur le xy -avions qui sont loin des autres. Si vous vous souvenez de vos cours de statistiques, ce sont des points comme ceux-ci qui peuvent affecter de manière significative la droite de régression linéaire calculée.
Si vous vouliez penser à une ligne qui pourrait traverser les quatre points, alors y=x serait un excellent choix car cette ligne passerait par tous les points.
Cependant, disons que je décide d'ajouter un autre point à (5,1) . Maintenant, quelle devrait être la droite de régression ? Eh bien, il s'avère que c'est complètement différent : y=0,2x 1,6
Compte tenu des points de données précédents, la ligne s'attendrait à ce que la valeur de y quand x=5 est 5, mais en raison de la valeur aberrante et de son MSE, la ligne de régression est « tirée vers le bas » de manière significative.
Ce n'est qu'un exemple simple, mais cela pose une question à laquelle vous, en tant que développeur d'apprentissage automatique, devez vous arrêter et réfléchir : Dans quelle mesure mon modèle doit-il être sensible aux valeurs aberrantes ? Si vous voulez que votre modèle pour être plus sensible aux valeurs aberrantes, vous choisirez alors une métrique comme MSE, car dans ce cas, les erreurs impliquant des valeurs aberrantes sont plus prononcées en raison de la mise au carré et votre modèle s'ajustera pour minimiser cela. Sinon, vous choisiriez une métrique comme MAE, qui ne se soucie pas autant des valeurs aberrantes.
Dans mon article précédent, j'ai également discuté du concept de rétropropagation, de descente de gradient et de la manière dont ils fonctionnent pour minimiser la perte du modèle. Le gradient est un vecteur qui pointe vers la direction du plus grand changement. Un algorithme de descente de gradient calculera ce vecteur et se déplacera dans la direction exactement opposée afin qu'il atteigne finalement un minimum.
La plupart des optimiseurs ont un taux d'apprentissage spécifique, communément appelé α auxquels ils adhèrent. Essentiellement, cela représente la mesure dans laquelle l'algorithme se déplacera vers le minimum à chaque fois qu'il calculera le gradient. Attention à ce que votre taux d'apprentissage soit trop élevé ! Votre algorithme risque de ne jamais atteindre le minimum en raison des étapes importantes qu'il nécessite et qui pourraient sauter le minimum à plusieurs reprises.
Pour en revenir à la descente de gradient, bien qu'elle soit efficace pour minimiser les pertes, cela pourrait ralentir considérablement le processus de formation car la fonction de perte est calculée sur l'ensemble de l'ensemble de données. Il existe plusieurs alternatives à la descente de pente qui sont plus efficaces mais ont leurs inconvénients respectifs.
L'une des alternatives les plus populaires à la descente de gradient standard est une variante appelée descente de gradient stochastique (SGD). Comme pour la descente de gradient, SGD a un taux d'apprentissage fixe. Mais plutôt que de parcourir l'intégralité de l'ensemble de données comme une descente de gradient, SGD prend un petit échantillon sélectionné au hasard et les poids de votre réseau neuronal sont mis à jour en fonction de l'échantillon. Finalement, les valeurs des paramètres convergent vers un point qui minimise approximativement (mais pas exactement) la fonction de perte. C'est l'un des inconvénients du SGD, car il n'atteint pas toujours le minimum exact. De plus, comme pour la descente de gradient, elle reste sensible au taux d'apprentissage que vous définissez.
Le nom, Adam, est dérivé de estimation adaptative du moment. Il combine essentiellement deux variantes de SGD pour ajuster le taux d'apprentissage pour chaque paramètre d'entrée en fonction de la fréquence à laquelle il est mis à jour au cours de chaque itération d'entraînement (taux d'apprentissage adaptatif). Dans le même temps, il garde également une trace des calculs de gradient passés en tant que moyenne mobile pour lisser les mises à jour (élan). Cependant, en raison de sa caractéristique d'impulsion, sa convergence peut parfois prendre plus de temps que d'autres algorithmes.
Maintenant, un exemple !
J'ai créé un exemple de procédure pas à pas sur Google Colab qui utilise PyTorch pour créer un réseau neuronal qui apprend une relation linéaire simple.
Si vous êtes un peu nouveau sur Python, ne vous inquiétez pas ! J'ai inclus quelques explications qui discutent de ce qui se passe dans chaque section.
Bien que cela ne couvre évidemment pas tout sur l'optimisation des réseaux de neurones, je voulais au moins aborder quelques-uns des concepts les plus importants dont vous pouvez profiter lors de la formation de vos propres modèles. J'espère que vous avez appris quelque chose cette semaine et merci d'avoir lu !
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!