Heim >Technologie-Peripheriegeräte >KI >Ist Flash Attention stabil? Meta und Harvard stellten fest, dass die Gewichtsabweichungen ihrer Modelle um Größenordnungen schwankten
Meta FAIR hat sich mit Harvard zusammengetan, um einen neuen Forschungsrahmen zur Optimierung von Datenverzerrungen bereitzustellen, die durch groß angelegtes maschinelles Lernen verursacht werden.
Wie wir alle wissen, dauert das Training großer Sprachmodelle oft Monate und nutzt Hunderte oder sogar Tausende von GPUs. Am Beispiel des Modells LLaMA2 70B erfordert dessen Training insgesamt 1.720.320 GPU-Stunden. Das Training großer Modelle stellt aufgrund des Umfangs und der Komplexität dieser Arbeitsbelastungen einzigartige systemische Herausforderungen dar.
In letzter Zeit haben viele Institutionen über Instabilität während des Trainingsprozesses beim Training generativer SOTA-KI-Modelle berichtet. Diese treten normalerweise in Form von Verlustspitzen auf, wie beispielsweise beim PaLM-Modell von Google, die während des Trainingsprozesses bis zu 20 Mal auftraten Spikes.
Numerische Abweichung ist die Hauptursache für diese Trainingsungenauigkeit. Aufgrund der extrem hohen Ausführungskosten beim Training großer Sprachmodelle ist die Quantifizierung numerischer Abweichungen zu einem zentralen Thema geworden.
In der neuesten Arbeit haben Forscher von Meta und der Harvard University eine prinzipielle quantitative Methode entwickelt, um numerische Verzerrungen bei der Trainingsoptimierung zu verstehen. Dies wird verwendet, um verschiedene hochmoderne Optimierungstechniken zu bewerten und festzustellen, ob sie beim Training großer Modelle zu unerwarteten Instabilitäten führen könnten. Die Forscher stellten fest, dass bestehende Optimierungsmethoden zwar bei einigen Aufgaben gut funktionierten, bei der Anwendung auf große Modelle jedoch einige numerische Abweichungen auftraten. Diese numerische Verzerrung kann während des Trainingsprozesses zu Instabilität führen und zu einer Verschlechterung der Modellleistung führen. Um dieses Problem zu lösen, schlugen die Forscher eine Optimierung vor, die auf prinzipiellen quantitativen Methoden basiert.
Link zum Papier: https://arxiv.org/pdf/2405.02803Es wurde festgestellt, dass die numerische Abweichung von Flash Attention bei einem einzelnen Vorwärtsdurchlauf eine Größenordnung größer war als die Baseline Attention von BF16.
Entwicklung eines Mikro-Benchmarks zur Störung der numerischen Genauigkeit in einer bestimmten Optimierung;
Auswertung der numerischen Werte durch datengesteuerte Analyse basierend auf der Wasserstein-Distanz Wie sich Bias in Änderungen der Modellgewichte niederschlägt.
Forscher haben die folgenden zwei Hauptbeiträge zur Quantifizierung numerischer Abweichungen geleistet:
(1) Entwarf einen Mikro-Benchmark, um den Einfluss der numerischen Genauigkeit auf numerische Abweichungen zu isolieren.
Der von den Forschern entwickelte Mikro-Benchmark ist eine Technik zur Messung und Quantifizierung der numerischen Abweichung, die durch herkömmliche Black-Box-Optimierung (wie Flash Attention) verursacht wird. Durch die Störung von Aspekten, die normalerweise in den bereitgestellten Kerneln nicht verfügbar sind, führten sie zu der Entdeckung, dass Flash Attention bei niedriger numerischer Präzision (BF16) im Vergleich zu Baseline Attention eine um etwa eine Größenordnung höhere numerische Abweichung aufweist.
(2) Durchführung einer datengesteuerten Analyse basierend auf der Wasserstein-Distanzmetrik.
Diese Studie unterstreicht die Bedeutung der Entwicklung eines prinzipiellen Ansatzes, um „die Auswirkungen der Trainingsoptimierung auf numerische Verzerrungen nicht nur zu quantifizieren, sondern auch zu kontextualisieren“, indem Proxys erstellt werden, um den numerischen Verzerrungskontext zu kontextualisieren, mit dem Ziel, die Wahrscheinlichkeit nachgelagerter Modelleffekte abzuleiten (d. h. , Trainingsinstabilitäten), die oft schwer zu messen sind.
Die Forscher entwickelten zunächst einen Mikro-Benchmark, um die durch Flash Attention verursachte numerische Abweichung zu isolieren und zu untersuchen. Wie in Abbildung 2 dargestellt, implementierten sie Flash Attention numerisch neu, um unterschiedliche numerische Genauigkeiten zu analysieren und potenzielle Optimierungsmaßnahmen bei jedem Schritt des Algorithmus anzuwenden.
Abbildung 2: Zusammenfassung des Mikrobenchmark-Designs.
Dies ist notwendig, da der Flash Attention-Kern derzeit nur die Zahlenformate FP16 und BF16 unterstützt. Dieser Kernel ist auch ein Wrapper-API-Aufruf für CUDA-Code, was es schwierig macht, den Algorithmus zu stören, um die Auswirkungen numerischer Verzerrungen zu untersuchen.
Im Gegensatz dazu ermöglicht ihr Mikro-Benchmark-Design eine präzise Eingabe und Änderung innerhalb des Algorithmus. Die Forscher verifizierten den Mikrobenchmark mit dem ursprünglichen Flash Attention-Kernel.
Sie haben außerdem eine Technik entwickelt, um die Ausgabe der Aufmerksamkeitsmatrix bei jedem Schritt während der Modellausführung zu vergleichen. Und der Modellcode wurde geändert, um bei jedem Aufmerksamkeitsaufruf die Baseline-Aufmerksamkeit und die Flash-Aufmerksamkeit zu berechnen, was einen genauen Vergleich der Ausgabematrix für dieselbe Eingabematrix ermöglicht.
Um dies in einen Zusammenhang zu bringen, haben wir auch die Metriken „Maximale Differenz“ und „Wasserstein-Distanz“ verwendet, um den Unterschied in den Modellgewichten während des Trainings mithilfe identischer und unabhängiger Trainingsläufe zu quantifizieren.
Für das Trainingsexperiment verwendeten die Forscher einen generativen KI-Workload (d. h. ein Text-zu-Bild-Modell), der Texteingaben in Bilder umwandelt. Sie trainierten das Modell mithilfe des Shutterstock-Datensatzes neu und führten das Experiment auf einem Cluster von NVIDIA 80 GB A100-GPUs durch.
Die Forscher analysierten zunächst die Auswirkungen von Flash Attention im Vorwärtspassprozess. Mithilfe von Mikrobenchmarks untersuchten sie den Einfluss unterschiedlicher numerischer Genauigkeiten auf die von Attention berechnete Ausgabematrix unter der Bedingung, dass die zufällig initialisierten Abfrage-, Schlüssel- und Wertevektoren gleich waren.
Wie in Abbildung 3 dargestellt, nimmt die numerische Abweichung zwischen Flash Attention und Baseline Attention ab, wenn die Anzahl der Mantissenstellen zunimmt, wenn Forscher unterschiedliche numerische Formate von BF16 bis FP64 verwenden. Dies deutet darauf hin, dass der numerische Unterschied auf die Näherung zurückzuführen ist, die mit weniger Mantissenstellen einhergeht.
Abbildung 3: Die Auswirkung des numerischen Formats auf die numerische Abweichung von Flash Attention.
Danach legte der Forscher einen „goldenen Wert“ für die Grundaufmerksamkeit im numerischen FP64-Format für den Standardvergleich fest und verglich dann die Aufmerksamkeitsausgabe in verschiedenen numerischen Formaten mit diesem Wert (wie in Abbildung 4 dargestellt).
Abbildung 4: Vergleich des „Goldwerts“ der Baseline Attention unter FP64.
Die Ergebnisse zeigen, dass die numerische Abweichung von Flash Attention etwa zehnmal so hoch ist wie die von Baseline unter BF16.
Um diese beobachtete numerische Abweichung weiter zu analysieren, scannten die Forscher die Sequenzlänge der Matrix, während sie die Kachelgröße und die SRAM-Größe konstant hielten (wie in Abbildung 5 dargestellt).
Abbildung 5: Die Auswirkung der Sequenzlänge auf die numerische Abweichung von Flash Attention.
Wie in der Abbildung gezeigt, ergibt sich mit zunehmender Sequenzlänge, unabhängig davon, ob sie anhand (a) der Obergrenze der maximalen Differenz oder (b) dem Mittelwert und der Standardabweichung der Differenz gemessen wird, die Differenz zwischen Flash Attention und Baseline Achtung Zahlenabweichungen nehmen zu.
Darüber hinaus verwenden Forscher auch Mikro-Benchmark-Designs, um Experimente mit verschiedenen Optimierungen durchzuführen, um die Auswirkungen numerischer Abweichungen besser zu verstehen (wie in Abbildung 6 dargestellt).
Abbildung 6a zeigt, wie das Vertauschen der Reihenfolge der Blockabmessungen dazu führt, dass der numerische Unterschied zwischen Flash Attention und Baseline Attention zunimmt. Andere Störungen in Abbildung 6b, wie z. B. die Beschränkung der Kachelgröße auf Quadrate, haben keinen Einfluss auf die numerische Verzerrung. Abbildung 6c zeigt, dass die numerische Abweichung umso geringer ist, je größer die Block-/Kachelgröße ist.
Abbildung 6: Algorithmusänderungen und ihre Auswirkung auf die beobachteten numerischen Abweichungen.
Während Flash Attention während des Vorwärtsdurchlaufs zu numerischer Verzerrung in der Aufmerksamkeitsausgabe führen kann, besteht das ultimative Ziel dieser Studie darin, festzustellen, ob dies während des Modelltrainings zu irgendwelchen Auswirkungen führt, um zu untersuchen, ob es trägt zur Trainingsinstabilität bei.
Daher hoffen die Forscher zu quantifizieren, ob Flash Attention das Modell während des Trainings verändert, d. h. ob sich der oben beobachtete Unterschied in der Aufmerksamkeitsausgabe in den aktualisierten Modellgewichten während des Trainings widerspiegelt.
Die Forscher verwendeten zwei Indikatoren, um den Unterschied in der Modellgewichtung zwischen Modellen, die mit Baseline Attention trainiert wurden, und Modellen, die mit Flash Attention trainiert wurden, zu messen. Zuerst wird die maximale Differenz berechnet, d. h. der Absolutwert der Differenz zwischen den Gewichtsmatrizen ermittelt und der Maximalwert genommen, um so die Obergrenze der Abweichung wie folgt zu erhalten:
Während die maximale Differenz bereitgestellt wird eine Obergrenze der numerischen Abweichung, berücksichtigt jedoch nicht die Verteilung jeder Matrix. Daher quantifizieren Forscher Gewichtsunterschiede anhand der Wasserstein-Distanz, die ein gängiges Maß für die Ähnlichkeit zwischen Tensoren ist. Obwohl die Berechnung etwas komplexer ist, umfasst die Wasserstein-Distanz Forminformationen der Tensorverteilung, um die Ähnlichkeit zu messen. Die Berechnungsformel lässt sich wie folgt zusammenfassen:
Je niedriger der Wert, desto höher die Ähnlichkeit zwischen den Matrizen.
Anhand dieser beiden Metriken quantifizierten die Forscher dann, wie sich die Modellgewichte der Flash-Aufmerksamkeit im Vergleich zur Basisaufmerksamkeit während des gesamten Trainingsprozesses veränderten:
Gemäß Wasserstein-Distanz und maximalem Unterschied dies für zwei Indikatoren während Während des gesamten Trainingsprozesses ändert sich durch das Hinzufügen von Flash Attention das Modellgewicht, und mit fortschreitendem Training wird dieser Unterschied immer größer. Dies zeigt, dass sich das mit Flash Attention trainierte Modell von dem mit Baseline Attention trainierten Modell unterscheidet. Das gleiche trainierte Modell konvergierte zu einem anderen Modell.
Training ist jedoch ein stochastischer Prozess, und bestimmte Änderungen in der Modellstruktur können zu ähnlichen Ergebnissen hinsichtlich nachgelagerter Effekte und Genauigkeit führen. Dies ist auch dann bemerkenswert, wenn die Gewichte der mit Flash Attention und Baseline Attention trainierten Modelle unterschiedlich sind.
Das vollständige Training eines Modells und die Bewertung der Genauigkeit ist eine kostspielige und ressourcenintensive Aufgabe, insbesondere bei großen Modellen, deren Training Monate dauert.
Der Forscher hat einen Proxy konfiguriert, um Folgendes zu untersuchen:
(a) Wie signifikant sind diese Gewichtsveränderungen?
(b) Kann dies mit Standardgewichtsveränderungen in anderen weit verbreiteten Trainingsoptimierungen zusammenhängen?
Um dieses Ziel zu erreichen, haben die Forscher eine Reihe von Experimenten entworfen, um zu vergleichen, wie sich der Gewichtsunterschied während des Trainingsprozesses in verschiedenen Szenarien verändert.
Zusätzlich zum Vergleich des Trainingsprozesses mit Flash Attention und Baseline Attention quantifizierten sie auch den Gewichtsunterschied während desselben Trainingsprozesses, bei dem die Gewichte zu Beginn des Trainings auf unterschiedliche Zufallswerte initialisiert wurden. Dies stellt eine Grenze dar, da die zufällige Gewichtungsinitialisierung eine gängige Technik ist und häufig zu gleichwertigen Ergebnissen führt.
Darüber hinaus haben die Forscher auch Veränderungen in den Modellgewichten gemessen, die mit unterschiedlichen Genauigkeiten trainiert wurden. Numerische Präzision (d. h. FP16 vs. FP32) kann zu nachgelagerten Änderungen führen, die als Obergrenze für die Bedeutung von Flash Attention-Gewichten dienen.
Wie in Abbildung 8 gezeigt, kann festgestellt werden, dass die Änderungsrate der Modellgewichtungsverzerrung mithilfe von Flash Attention mit der Änderungsrate der Gewichtsverzerrung verschiedener Modellinitialisierungen vergleichbar oder kleiner ist (beachten Sie die Steigung der roten und blauen Kurven). .
Darüber hinaus ist die Gewichtsänderungsrate bei Verwendung von FP16 und FP32 höher und die Änderung ist größer als bei der Initialisierung verschiedener Modelle.
Diese Ergebnisse liefern einen Anhaltspunkt und zeigen: „Obwohl Flash Attention eine numerische Verzerrung aufweist, wird diese durch zufällige Modellinitialisierung und Training mit geringer Präzision begrenzt. Und die eingeführte Modellgewichtungsverzerrung beträgt etwa 10 %, wenn mit geringer Präzision trainiert wird.“ . 1/2 bis 1/5 Mal.
Weitere Forschungsdetails finden Sie im Originalpapier.
Das obige ist der detaillierte Inhalt vonIst Flash Attention stabil? Meta und Harvard stellten fest, dass die Gewichtsabweichungen ihrer Modelle um Größenordnungen schwankten. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!