Maison  >  Article  >  Périphériques technologiques  >  Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace l'état caché RNN

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace l'état caché RNN

WBOY
WBOYoriginal
2024-07-17 16:08:17439parcourir

Les performances des grands modèles ont été améliorées de 125M à 1,3B.


Incroyable, cela est finalement arrivé.

Une nouvelle architecture de grand modèle de langage (LLM) devrait remplacer Transformer, qui a été populaire jusqu'à présent dans le domaine de l'IA, et ses performances sont meilleures que celles de Mamba. Lundi, un article sur la formation au test (TTT) est devenu un sujet brûlant dans la communauté de l'intelligence artificielle.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Lien papier : https://arxiv.org/abs/2407.04620

Les auteurs de cette étude sont issus de l'Université de Stanford, de l'Université de Californie, de Berkeley, de l'Université de Californie, de San Diego et de Meta. Ils ont conçu une nouvelle architecture, TTT, qui a remplacé l'état caché du RNN par un modèle d'apprentissage automatique. Le modèle compresse le contexte grâce à une descente de gradient réelle des jetons d'entrée.

Karan Dalal, l'un des auteurs de l'étude, a déclaré qu'il pensait que cela changerait fondamentalement l'approche du modèle linguistique.
Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN
Dans les modèles d'apprentissage automatique, la couche TTT remplace directement Attention et déverrouille l'architecture de complexité linéaire grâce à la mémoire expressive, nous permettant d'entraîner LLM avec des millions (parfois des milliards) de jetons en contexte.

L'auteur a mené une série de comparaisons sur de grands modèles avec des tailles de paramètres allant de 125M à 1,3B et a constaté que TTT-Linear et TTT-MLP peuvent égaler ou vaincre les méthodes d'architecture Transformers et Mamba les plus puissantes.

En tant que nouveau mécanisme de compression d'informations et de mémoire de modèle, la couche TTT peut simplement et directement remplacer la couche d'auto-attention dans Transformer.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Comparé à Mamba, TTT-Linear a moins de perplexité, moins de FLOP (à gauche) et une meilleure utilisation des contextes longs (à droite) :

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Ce n'est pas seulement linéaire en termes de complexité théorique, mais aussi en termes de fonctionnement réel. le temps est également plus rapide.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

  • Après la mise en ligne de l'article, l'auteur a rendu le code et jax publics pour que les gens puissent s'entraîner et tester : https://github.com/test-time-training/ttt-lm-jax
  • Aussi le code d'inférence PyTorch : https://github.com/test-time-training/ttt-lm-pytorch

Introduction à la méthode

Le défi du contexte long est intrinsèque au nature des couches RNN : contrairement au mécanisme d'auto-attention, la couche RNN doit compresser le contexte dans un état caché de taille fixe, et les règles de mise à jour doivent découvrir la structure sous-jacente et les relations entre des milliers, voire des millions de jetons.

L'équipe de recherche a d'abord observé que l'apprentissage auto-supervisé peut compresser de grands ensembles d'entraînement en poids pour des modèles tels que LLM, et que les modèles LLM présentent souvent une compréhension approfondie des connexions sémantiques entre leurs données d'entraînement.

Inspirée par cette observation, l'équipe de recherche a conçu une nouvelle classe de couches de modélisation de séquences, où l'état caché est un modèle et la règle de mise à jour est une étape de l'apprentissage auto-supervisé. Étant donné que le processus de mise à jour de l’état caché sur la séquence de test équivaut à entraîner le modèle au moment du test, l’équipe de recherche appelle cette nouvelle couche la couche Test-Time Training (TTT).

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

L'équipe de recherche présente deux exemples simples : TTT-Linear et TTT-MLP, où les états cachés sont respectivement des modèles linéaires et MLP à deux couches. Les couches TTT peuvent être intégrées dans n'importe quelle architecture réseau et optimisées de bout en bout, à l'instar des couches RNN et de l'auto-attention.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Afin de rendre la couche TTT plus efficace, l'étude a adopté quelques astuces pour améliorer la couche TTT :

Premièrement, comme si vous preniez une étape de gradient pour les séquences de mini-batchs lors d'un entraînement régulier afin d'obtenir un meilleur parallélisme, l'étude Utilisez de petits lots de jetons pendant TTT.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Deuxièmement, l'étude développe une double forme pour les opérations au sein de chaque mini-lot TTT afin de mieux utiliser les GPU et TPU modernes. Le résultat du formulaire double est équivalent à la simple implémentation, mais la formation est plus de 5 fois plus rapide. Comme le montre la figure 3, TTT-Linear est plus rapide que Transformer et comparable à Mamba dans le contexte 8k.

L'équipe de recherche estime que toutes les couches de modélisation de séquence peuvent être considérées comme stockant le contexte historique dans un état caché, comme le montre la figure 4.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Par exemple, les couches RNN telles que les couches LSTM, RWKV et Mamba compressent le contexte dans un état de taille fixe au fil du temps. Cette compression a deux conséquences : d'une part, mapper les jetons d'entrée x_t aux jetons de sortie z_t est efficace car les règles de mise à jour et les règles de sortie pour chaque jeton nécessitent un temps constant. D’un autre côté, les performances d’une couche RNN dans des contextes longs sont limitées par l’expressivité de ses états cachés s_t.

L'auto-attention peut également être vue du point de vue ci-dessus, sauf que son état caché (souvent appelé cache clé-valeur) est une liste qui croît linéairement avec t. Sa règle de mise à jour ajoute simplement le tuple KV actuel à cette liste, tandis que sa règle de sortie analyse tous les tuples avant t pour former la matrice d'attention. L'état caché stocke explicitement tout le contexte historique sans compression, ce qui rend l'attention personnelle plus expressive que les couches RNN pour les contextes longs. Cependant, le temps nécessaire pour analyser cet état caché à croissance linéaire augmente également de manière linéaire. Pour que les contextes longs restent efficaces et expressifs, les chercheurs ont besoin d’une meilleure heuristique de compression. Plus précisément, des milliers, voire des millions de jetons doivent être compressés dans un état caché qui capture efficacement leur structure et leurs relations sous-jacentes. Cela peut sembler difficile, mais de nombreuses personnes connaissent très bien cette heuristique.

Architecture de base. Le moyen le plus simple d'intégrer n'importe quelle couche RNN dans une architecture plus large consiste à remplacer directement l'auto-attention dans Transformer, appelée ici le backbone. Cependant, les RNN existants (tels que Mamba et Griffin) utilisent des couches de base différentes de Transformer. Plus particulièrement, leurs couches de base contiennent des convolutions temporelles avant la couche RNN, ce qui peut aider à collecter des informations locales au fil du temps. Après avoir expérimenté le squelette Mamba, les chercheurs ont découvert qu'il pouvait également améliorer la perplexité de la couche TTT. Il a donc été inclus dans la méthode proposée, comme le montre la figure 16.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Résultats expérimentaux

Dans l'expérience, les chercheurs ont comparé TTT-Linear et TTT-MLP avec Transformer et Mamba, deux lignes de base.

Texte court

De la figure 11 nous pouvons tirer les conclusions suivantes :

  • 2k, les performances de TTT-Linear (M), Mamba et Transformer sont comparables car des lignes se chevauchent pour la plupart. TTT-MLP (M) fonctionne légèrement moins bien avec un budget FLOP plus important. Bien que TTT-MLP présente une meilleure perplexité que TTT-Linear pour différentes tailles de modèles, le coût supplémentaire des FLOP compense cet avantage.
  • Pour le contexte 8k, TTT-Linear (M) et TTT-MLP (M) fonctionnent nettement mieux que Mamba, ce qui est assez différent de l'observation dans le contexte 2k. Même TTT-MLP (T) utilisant le réseau fédérateur Transformer est légèrement meilleur que Mamba à environ 1,3 milliard. Un phénomène important est qu'à mesure que la longueur du contexte augmente, les avantages de la couche TTT par rapport à la couche Mamba s'étendent également.
  • Avec une longueur de contexte atteignant 8k, Transformer fonctionne toujours bien dans la perplexité sous chaque taille de modèle, mais il n'est plus compétitif en raison du coût des FLOP.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Les résultats ci-dessus montrent l'impact du basculement de la couche TTT du réseau fédérateur Mamba vers le réseau fédérateur Transformer. Les chercheurs ont émis l’hypothèse que les convolutions temporelles dans le réseau fédérateur Mamba seraient plus utiles lorsque les états cachés de la couche de modélisation de séquence sont moins expressifs. Les modèles linéaires sont moins expressifs que les MLP et bénéficient donc davantage des convolutions.

Texte long : Livres

Pour évaluer la capacité des contextes longs, nous avons utilisé Books3, un sous-ensemble populaire de Pile, pour expérimenter des longueurs de contexte de 1 000 à 32 000 par incréments de 2x. La méthode de formation ici est la même que celle de Pile, et toutes les expériences pour la couche TTT sont effectuées en une seule fois. À partir du sous-ensemble de résultats de la figure 12, ils ont fait les observations suivantes :

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Dans le contexte de Books 2k, toutes les observations pour Pile 2k sont toujours valables, sauf que Mamba fonctionne désormais légèrement mieux que TTT-Linear (et leurs lignes se chevauchent à peu près dans la pile 2k).

Dans le contexte 32k, TTT-Linear (M) et TTT-MLP (M) fonctionnent mieux que Mamba, similaires aux observations pour Pile 8k. Même TTT-MLP (T) avec le backbone Transformer fonctionne légèrement mieux que Mamba dans un contexte 32k.

TTT-MLP (T) n'est que légèrement pire que TTT-MLP (M) à l'échelle 1,3B. Comme mentionné ci-dessus, il est difficile de dériver une loi d’échelle empirique en raison de l’absence d’un ajustement linéaire clair. Cependant, la forte tendance du TTT-MLP (T) suggère que le backbone Transformer pourrait être mieux adapté aux modèles plus grands et aux contextes plus longs, au-delà de la portée de notre évaluation.

Clock Time

La formation et l'inférence du LLM peuvent être décomposées en avant, en arrière et en génération. Le traitement des mots de repère pendant l'inférence (également appelé pré-remplissage) est le même que l'opération avant pendant l'entraînement, sauf que l'opération arrière ne nécessite pas le stockage de valeurs d'activation intermédiaires.

Étant donné que l'avant (pendant l'entraînement et l'inférence) et l'arrière peuvent être traités en parallèle, la forme double est utilisée ici. La génération de nouveaux jetons (également appelée décodage) est de nature séquentielle, c'est pourquoi la forme brute est utilisée ici.

Le chercheur a mentionné qu'en raison de ressources limitées, l'expérience présentée dans cet article a été écrite en JAX et exécutée sur TPU. Sur un pod TPU v5e-256, la ligne de base Transformer prend 0,30 seconde par itération pour s'entraîner avec 2 000 contextes, tandis que TTT-Linear prend 0,27 seconde par itération, ce qui est 10 % plus rapide sans aucune optimisation du système. Étant donné que Mamba (implémenté avec PyTorch, Triton et CUDA) ne peut fonctionner que sur GPU, afin de faire une comparaison équitable, les chercheurs ont procédé à une optimisation préliminaire du système de cette méthode afin qu'elle puisse fonctionner sur GPU.

Le côté gauche de la figure 15 montre la latence du noyau avant pour chaque modèle pour une taille de lot de 16. Tous les modèles sont 1,3B (Mamba est 1,4B). Il convient de noter que la ligne de base de Transformer ici est beaucoup plus rapide que celle de l'article Mamba car vLLM est utilisé ici à la place de HuggingFace Transformer.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

De plus, les chercheurs ont également écrit un autre noyau GPU pour la génération et ont comparé sa vitesse avec une taille de lot de 512 sur le côté droit de la figure 15. Une autre mesure de temps d’horloge murale couramment utilisée est le débit, qui prend en compte les avantages potentiels de l’utilisation de lots de plus grande taille. Pour le débit, toutes les observations ci-dessus et l'ordre entre les méthodes sont toujours valables.

Auteur principal

Après la soumission de l'étude TTT, l'un des auteurs de l'article, le professeur adjoint de l'UCSD Xiaolong Wang, a tweeté ses félicitations. Il a déclaré que la recherche sur le TTT a duré un an et demi, mais que cela fait en réalité cinq ans que l'idée du Test Time Training (TTT) est née. Bien que l’idée originale et les résultats actuels soient complètement différents.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Les trois principaux auteurs de l'article TTT sont respectivement originaires de Stanford, UC Berkeley et UCSD.

Parmi eux, Yu Sun est chercheur postdoctoral à l'Université de Stanford. Il est diplômé de l'UC Berkeley EECS avec un doctorat et sa direction de recherche à long terme est TTT.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Xinhao Li est doctorant à l'UCSD. Il est diplômé de l'Université des sciences et technologies électroniques de Chine.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Karan Dalal est doctorante à l'UC Berkeley et a cofondé une startup de télémédecine vétérinaire appelée Otto alors qu'elle était au lycée.

Changer complètement le modèle de langage : la nouvelle architecture TTT surpasse le Transformer, et le modèle ML remplace létat caché RNN

Les trois personnes ci-dessus ont toutes rédigé une formation de test en première ligne de leur site Web personnel présentant les orientations de la recherche.

Pour plus de détails sur la recherche, veuillez vous référer à l'article original.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn