Heim  >  Artikel  >  Technologie-Peripheriegeräte  >  Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

WBOY
WBOYOriginal
2024-07-17 16:08:17439Durchsuche

Die Leistung großer Modelle wurde von 125M auf 1,3B verbessert.


Unglaublich, das ist endlich passiert.

Eine neue LLM-Architektur (Large Language Model) soll Transformer ersetzen, das bisher im KI-Bereich beliebt war und eine bessere Leistung als Mamba aufweist. Am Montag wurde ein Artikel über Test-Time-Training (TTT) zu einem heißen Thema in der Community der künstlichen Intelligenz.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Link zum Papier: https://arxiv.org/abs/2407.04620

Die Autoren dieser Studie stammen von der Stanford University, der University of California, Berkeley, der University of California, San Diego und Meta. Sie entwarfen eine neue Architektur, TTT, die den verborgenen Zustand von RNN durch ein Modell für maschinelles Lernen ersetzte. Das Modell komprimiert den Kontext durch den tatsächlichen Gradientenabfall der Eingabe-Tokens.

Karan Dalal, einer der Autoren der Studie, sagte, er glaube, dass dies den Ansatz des Sprachmodells grundlegend verändern werde.
Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand
In Modellen für maschinelles Lernen ersetzt die TTT-Schicht direkt die Aufmerksamkeit und erschließt die lineare Komplexitätsarchitektur durch ausdrucksstarkes Gedächtnis, sodass wir LLM mit Millionen (manchmal Milliarden) Token im Kontext trainieren können.

Der Autor führte eine Reihe von Vergleichen an großen Modellen mit Parametergrößen von 125 M bis 1,3 B durch und stellte fest, dass sowohl TTT-Linear als auch TTT-MLP mit den leistungsstärksten Transformers- und Mamba-Architekturmethoden mithalten oder diese übertreffen können.

Als neuer Informationskomprimierungs- und Modellspeichermechanismus kann die TTT-Schicht die Selbstaufmerksamkeitsschicht in Transformer einfach und direkt ersetzen.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Im Vergleich zu Mamba hat TTT-Linear eine geringere Verwirrung, weniger FLOPs (links) und eine bessere Nutzung langer Kontexte (rechts):

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Dies ist nicht nur linear in der theoretischen Komplexität, sondern auch in der tatsächlichen Ausführung Die Zeit ist auch schneller.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

  • Nachdem das Papier online ging, machte der Autor den Code und Jax öffentlich, damit die Leute trainieren und testen können: https://github.com/test-time-training/ttt-lm-jax
  • Auch PyTorch-Inferenzcode: https://github.com/test-time-training/ttt-lm-pytorch Natur der RNN-Schichten: Im Gegensatz zum Selbstaufmerksamkeitsmechanismus muss die RNN-Schicht den Kontext in einen verborgenen Zustand fester Größe komprimieren, und die Aktualisierungsregeln müssen die zugrunde liegende Struktur und Beziehungen zwischen Tausenden oder sogar Millionen von Token ermitteln.

Das Forschungsteam beobachtete zunächst, dass selbstüberwachtes Lernen große Trainingssätze in Gewichtungen für Modelle wie LLM komprimieren kann und LLM-Modelle oft ein tiefes Verständnis der semantischen Verbindungen zwischen ihren Trainingsdaten aufweisen.
Inspiriert von dieser Beobachtung entwarf das Forschungsteam eine neue Klasse von Sequenzmodellierungsschichten, bei denen der verborgene Zustand ein Modell und die Aktualisierungsregel ein Schritt des selbstüberwachten Lernens ist. Da der Prozess der Aktualisierung des verborgenen Zustands in der Testsequenz dem Training des Modells zur Testzeit entspricht, nennt das Forschungsteam diese neue Schicht die Schicht „Test-Time Training“ (TTT).

Das Forschungsteam stellt zwei einfache Beispiele vor: TTT-Linear und TTT-MLP, wobei die verborgenen Zustände lineare Modelle bzw. zweischichtiges MLP sind. TTT-Schichten können in jede Netzwerkarchitektur integriert und Ende-zu-Ende optimiert werden, ähnlich wie RNN-Schichten und Selbstaufmerksamkeit.

Um die TTT-Schicht effizienter zu machen, wurden in der Studie einige Tricks zur Verbesserung der TTT-Schicht übernommen:

Erstens, ähnlich wie bei der Durchführung eines Gradientenschritts für Mini-Batch-Sequenzen während des regulären Trainings, um eine bessere Parallelität zu erzielen, die Studie Verwenden Sie während der TTT kleine Mengen an Token.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Zweitens entwickelt die Studie eine duale Form für Vorgänge innerhalb jedes TTT-Mini-Batches, um moderne GPUs und TPUs besser zu nutzen. Die Ausgabe der dualen Form entspricht der einfachen Implementierung, das Training ist jedoch mehr als fünfmal schneller. Wie in Abbildung 3 dargestellt, ist TTT-Linear schneller als Transformer und im 8k-Kontext mit Mamba vergleichbar.

Das Forschungsteam geht davon aus, dass alle Ebenen der Sequenzmodellierung als Speicherung historischer Kontexte in einem verborgenen Zustand betrachtet werden können, wie in Abbildung 4 dargestellt.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Zum Beispiel komprimieren RNN-Schichten wie LSTM-, RWKV- und Mamba-Schichten den Kontext im Laufe der Zeit in einen Zustand fester Größe. Diese Komprimierung hat zwei Konsequenzen: Einerseits ist die Zuordnung der Eingabetokens x_t zu den Ausgabetokens z_t effizient, da die Aktualisierungsregeln und Ausgaberegeln für jedes Token eine konstante Zeit erfordern. Andererseits wird die Leistung einer RNN-Schicht in langen Kontexten durch die Ausdruckskraft ihrer verborgenen Zustände s_t begrenzt.

Selbstaufmerksamkeit kann auch aus der obigen Perspektive betrachtet werden, mit der Ausnahme, dass ihr verborgener Zustand (oft als Schlüsselwert-Cache bezeichnet) eine Liste ist, die linear mit t wächst. Seine Aktualisierungsregel hängt einfach das aktuelle KV-Tupel an diese Liste an, während seine Ausgaberegel alle Tupel vor t durchsucht, um die Aufmerksamkeitsmatrix zu bilden. Der verborgene Zustand speichert explizit den gesamten historischen Kontext ohne Komprimierung, was die Selbstaufmerksamkeit für lange Kontexte ausdrucksvoller macht als RNN-Schichten. Allerdings wächst auch die Zeit, die zum Scannen dieses linear wachsenden verborgenen Zustands erforderlich ist, linear an. Um lange Kontexte effizient und ausdrucksstark zu halten, benötigen Forscher eine bessere Komprimierungsheuristik. Insbesondere müssen Tausende oder möglicherweise Millionen von Token in einen verborgenen Zustand komprimiert werden, der ihre zugrunde liegende Struktur und Beziehungen effektiv erfasst. Das hört sich vielleicht schwierig an, aber viele Menschen sind mit dieser Heuristik tatsächlich sehr vertraut.

Backbone-Architektur. Der sauberste Weg, eine RNN-Schicht in eine größere Architektur zu integrieren, besteht darin, die Selbstaufmerksamkeit in Transformer, hier als Backbone bezeichnet, direkt zu ersetzen. Bestehende RNNs (wie Mamba und Griffin) verwenden jedoch andere Backbone-Schichten als Transformer. Vor allem enthalten ihre Backbone-Schichten vor der RNN-Schicht zeitliche Faltungen, die dabei helfen können, lokale Informationen über die Zeit hinweg zu sammeln. Nach Experimenten mit dem Mamba-Rückgrat stellten die Forscher fest, dass es auch die Perplexität der TTT-Schicht verbessern konnte, weshalb es in die vorgeschlagene Methode einbezogen wurde, wie in Abbildung 16 dargestellt.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Experimentelle Ergebnisse

Im Experiment verglichen die Forscher TTT-Linear und TTT-MLP mit Transformer und Mamba, zwei Basislinien.

Kurztext

Aus Abbildung 11 können wir folgende Schlussfolgerungen ziehen:

  • 2k-Kontext, die Leistung von TTT-Linear (M), Mamba und Transformer ist vergleichbar, weil der Linien überlappen sich größtenteils. TTT-MLP (M) schneidet bei größerem FLOP-Budget etwas schlechter ab. Obwohl TTT-MLP bei verschiedenen Modellgrößen eine bessere Verwirrung als TTT-Linear aufweist, wird dieser Vorteil durch die zusätzlichen Kosten von FLOPs ausgeglichen.
  • Im 8k-Kontext schneiden sowohl TTT-Linear (M) als auch TTT-MLP (M) deutlich besser ab als Mamba, was sich deutlich von der Beobachtung im 2k-Kontext unterscheidet. Sogar TTT-MLP (T), das das Transformer-Backbone-Netzwerk verwendet, ist mit etwa 1,3 B etwas besser als Mamba. Ein wesentliches Phänomen besteht darin, dass mit zunehmender Kontextlänge auch die Vorteile der TTT-Schicht gegenüber der Mamba-Schicht zunehmen.
  • Mit einer Kontextlänge von 8 KB schneidet Transformer bei Perplexität unter jeder Modellgröße immer noch gut ab, ist jedoch aufgrund der Kosten für FLOPs nicht mehr wettbewerbsfähig.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Die obigen Ergebnisse zeigen die Auswirkungen des Wechsels der TTT-Schicht vom Mamba-Backbone-Netzwerk zum Transformer-Backbone-Netzwerk. Die Forscher stellten die Hypothese auf, dass zeitliche Faltungen im Mamba-Backbone-Netzwerk hilfreicher sind, wenn die verborgenen Zustände der Sequenzmodellierungsschicht weniger aussagekräftig sind. Lineare Modelle sind weniger ausdrucksstark als MLPs und profitieren daher stärker von Faltungen.

Langtext: Bücher

Um die Fähigkeit langer Kontexte zu bewerten, haben wir Books3, eine beliebte Teilmenge von Pile, verwendet, um mit Kontextlängen von 1.000 bis 32.000 in 2x-Schritten zu experimentieren. Die Trainingsmethode ist hier die gleiche wie bei Pile, und alle Experimente für die TTT-Schicht werden in einem Trainingslauf durchgeführt. Aus der Teilmenge der Ergebnisse in Abbildung 12 machten sie die folgenden Beobachtungen:

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Im Kontext von Books 2k gelten alle Beobachtungen für Pile 2k immer noch, mit der Ausnahme, dass Mamba jetzt etwas besser abschneidet als TTT-Linear (und ihre Linien überlappen sich ungefähr in Stapel 2k).

Im 32k-Kontext schneiden sowohl TTT-Linear (M) als auch TTT-MLP (M) besser ab als Mamba, ähnlich den Beobachtungen für Pile 8k. Selbst TTT-MLP (T) mit Transformer-Backbone schneidet im 32k-Kontext etwas besser ab als Mamba.

TTT-MLP (T) ist im Maßstab 1,3B nur geringfügig schlechter als TTT-MLP (M). Wie oben erwähnt, ist es aufgrund des Fehlens einer klaren linearen Anpassung schwierig, ein empirisches Skalierungsgesetz abzuleiten. Der starke Trend bei TTT-MLP (T) deutet jedoch darauf hin, dass das Transformer-Backbone möglicherweise besser für größere Modelle und längere Kontexte geeignet ist, was über den Rahmen unserer Bewertung hinausgeht.

Uhrzeit

Das Training und die Schlussfolgerung von LLM können in Vorwärts, Rückwärts und Generierung zerlegt werden. Die Cue-Wortverarbeitung während der Inferenz (auch Pre-Population genannt) ist die gleiche wie die Vorwärtsoperation während des Trainings, mit der Ausnahme, dass die Rückwärtsoperation keine Speicherung von Zwischenaktivierungswerten erfordert.

Da sowohl Vorwärts (beim Training und Inferenz) als auch Rückwärts parallel verarbeitet werden können, wird hier die duale Form verwendet. Die Generierung neuer Token (auch Dekodierung genannt) erfolgt sequentiell, daher wird hier die Rohform verwendet.

Der Forscher erwähnte, dass das Experiment in diesem Artikel aufgrund von Ressourcenbeschränkungen in JAX geschrieben wurde und auf TPU lief. Auf einem v5e-256-TPU-Pod benötigt die Transformer-Basislinie 0,30 Sekunden pro Iteration, um mit 2K-Kontexten zu trainieren, während TTT-Linear 0,27 Sekunden pro Iteration benötigt, was ohne Systemoptimierungen 10 % schneller ist. Da Mamba (implementiert mit PyTorch, Triton und CUDA) nur auf der GPU ausgeführt werden kann, führten die Forscher für einen fairen Vergleich eine vorläufige Systemoptimierung dieser Methode durch, damit sie auf der GPU ausgeführt werden kann.

Die linke Seite von Abbildung 15 zeigt die Latenz des Vorwärtskernels für jedes Modell bei einer Stapelgröße von 16. Alle Modelle sind 1.3B (Mamba ist 1.4B). Es ist erwähnenswert, dass die Transformer-Basislinie hier viel schneller ist als die im Mamba-Artikel, da hier vLLM anstelle von HuggingFace Transformer verwendet wird.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Darüber hinaus haben die Forscher auch einen weiteren GPU-Kernel zur Generierung geschrieben und dessen Geschwindigkeit mit einer Batchgröße von 512 auf der rechten Seite von Abbildung 15 verglichen. Eine weitere häufig verwendete Zeitmetrik ist der Durchsatz, der die potenziellen Vorteile der Verwendung größerer Chargengrößen berücksichtigt. Für den Durchsatz gelten weiterhin alle oben genannten Beobachtungen und die Reihenfolge zwischen den Methoden.

Hauptautor

Nachdem die TTT-Studie eingereicht wurde, twitterte einer der Autoren des Papiers, UCSD-Assistenzprofessor Xiaolong Wang, seine Glückwünsche. Er sagte, dass die Forschung zu TTT anderthalb Jahre gedauert habe, aber tatsächlich sei es fünf Jahre her, seit die Idee des Test Time Training (TTT) geboren wurde. Obwohl die ursprüngliche Idee und die aktuellen Ergebnisse völlig unterschiedlich sind.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Die drei Hauptautoren des TTT-Papiers kommen aus Stanford, UC Berkeley bzw. UCSD.

Unter ihnen ist Yu Sun Postdoktorand an der Stanford University. Er schloss sein Studium an der UC Berkeley EECS mit einem Ph.D. ab und seine langfristige Forschungsrichtung ist TTT.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Xinhao Li ist Doktorand an der UCSD. Er hat seinen Abschluss an der University of Electronic Science and Technology of China gemacht.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Karan Dalal ist Doktorand an der UC Berkeley und war während seiner Schulzeit Mitbegründer eines Veterinär-Telemedizin-Startups namens Otto.

Ändern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand

Die oben genannten drei Personen haben alle in der ersten Zeile ihrer persönlichen Websites eine Testzeitschulung geschrieben, in der sie Forschungsrichtungen vorstellten.

Weitere Forschungsdetails finden Sie im Originalpapier.

Das obige ist der detaillierte Inhalt vonÄndern Sie das Sprachmodell vollständig: Die neue TTT-Architektur übertrifft den Transformer und das ML-Modell ersetzt den verborgenen RNN-Zustand. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn