Meta FAIR 聯合哈佛優化大規模機器學習時產生的資料偏差,提供了新的研究架構。
據所周知,大語言模型的訓練常常需要數月的時間,使用數百甚至上千個GPU。以LLaMA2 70B模型為例,其訓練總共需要1,720,320個GPU小時。由於這些工作負載的規模和複雜性,導致訓練大模型存在著獨特的系統性挑戰。
最近,許多機構在訓練SOTA生成式AI模型時報告了訓練過程中的不穩定情況,它們通常以損失尖峰的形式出現,例如Google的PaLM模型訓練過程中出現了多達20次的損失尖峰。
數值偏差是造成這種訓練不準確性的根因,由於大語言模型訓練執行成本極高,如何量化數值偏差儼然成為關鍵問題。
在最新的一項工作中,來自 Meta、哈佛大學的研究者開發了一個原則性定量方法來理解訓練優化中的數值偏差。以此評估不同的最新最佳化技術,並確定它們在用於訓練大模型時是否可能引入意外的不穩定性。 研究者發現,儘管現有的最佳化方法在一些任務上表現出色,但在大型模型上應用時,會出現一些數值偏差。這種數值偏差可能會在訓練過程中產生不穩定性,導致模型的表現下降。 為了解決這個問題,研究者提出了一種基於原則性量化方法的最佳化
- 論文標題:Is Flash Attention Stable?
- 論文連結:https://arxiv.org/pdf/2405.02803
結果發現,在單獨的前向傳遞過程中,Flash Attention 的數值偏差比BF16 的Baseline Attention 大一個數量級。
具體而言,該方法包括兩個階段,包括:
- #開發一個微基準來擾動給定最佳化中的數值精度;
- 透過基於Wasserstein 距離的資料驅動分析評估數值偏差如何轉換為模型權重的變化。
研究者分析了 SOTA 最佳化技術 Flash Attention,並量化了可可能引入的數值偏差。 Flash Attention 是一種廣泛用於加速注意力機制的技術,通常被認為是 Transformer 模型中的系統瓶頸。 Flash Attention 在提高速度和減少記憶體存取量的同時,也依賴演算法最佳化,而演算法最佳化有可能導致數值偏差的增加。
研究者假設添加重新縮放因子(rescaling factors )可能會引入無意的近似,導致數值折衷,這可能會在後續影響訓練穩定性。
他們在多模態文字到影像工作負載的背景下分析了 Flash Attention,以確定 Flash Attention 與其基準之間數值偏差的潛在重要性。最終,他們引入了一個框架來量化訓練優化的數值偏差及其下游影響。
研究者在數值偏差量化上主要做出了以下兩點貢獻:
(1)設計了一個微基準來分離數值精度對數值偏差的影響。
研究者所設計的微基準作為一種技術,用於衡量和量化傳統黑盒最佳化(如 Flash Attention)所導致的數值偏差。透過擾動通常在提供的內核中不可用的方面,他們開創性地發現在低數值精度(BF16)下,與 Baseline Attention 相比,Flash Attention 的數值偏差大約高出一個數量級。
(2)基於 Wasserstein Distance 度量進行了資料驅動的分析。
透過此分析,研究者將觀察到的數值偏差置於上下文,並為其對下游模型屬性的影響形成一個上限(upper bound)。在研究者的案例研究中,他們能夠限制觀察到的數值偏差的影響,並發現:「Flash Attention 引入的模型權重偏差大約為低精度訓練的1/2 至1/5 倍。」
這項研究強調了開發一種原則性方法的重要性:「不僅要量化,而且要將訓練優化對數值偏差的影響置於上下文中。」透過建立代理(proxies)來將數值偏差置於上下文中,旨在推斷通常難以衡量的下游模型效果(即訓練不穩定性)的可能性。
實驗方法
研究者首先發展了一個微基準來分離並研究 Flash Attention 所造成的數值偏差。如圖 2 所示,他們透過對 Flash Attention 進行數值上的重新實現,以分析不同的數值精度,並在演算法的每個步驟應用潛在的最佳化措施。
圖 2: 微基準設計摘要。
這是必要的,因為 Flash Attention 核心目前僅支援 FP16 和 BF16 數值格式。該核心也是 CUDA 程式碼的包裝 API 調用,這使得擾動演算法以檢查數值偏差的影響變得具有挑戰性。
相比之下,他們的微基準設計允許在演算法內部進行精度輸入和修改。研究者將微基準與原始的 Flash Attention kernel 進行了驗證。
他們進一步設計了一種技術,以比較模型執行過程中每個步驟的 Attention 矩陣的輸出。並修改了模型程式碼,每次呼叫注意力時都計算 Baseline Attention 和 Flash Attention,這允許對相同的輸入矩陣進行精確的輸出矩陣比較。
為了將其置於上下文中,研究者也透過相同和獨立的訓練運行,使用 Max difference 和 Wasserstein Distance 度量來量化模型權重在整個訓練過程中的差異。
對於訓練實驗,研究者則使用一種將文字輸入轉換為圖像的生成式 AI workload(即文字到圖像模型)。他們使用 Shutterstock 資料集重新訓練模型,並在一組英偉達 80GB A100 GPU 叢集上執行此實驗。
透過微基準量化數值偏差
#研究者首先分析了 Flash Attention 在前向傳遞過程中的影響。他們利用微基準測試,在隨機初始化查詢、鍵、值向量相同的情況下,檢驗不同數值精確度對 Attention 計算的輸出矩陣的影響。
如圖3 所示,當研究者使用從BF16 到FP64 變化的不同數值格式時,Flash Attention 和Baseline Attention 之間的數值偏差隨著尾數位數的增加而減小。這表明數值差異是由於較少的尾數位數所固有的近似造成的。
圖 3:數值格式對於 Flash Attention 的數值偏差所產生的效果。
之後,研究者為進行標準比較,在FP64 數值格式下的Baseline Attention 設定了「黃金值」,然後將不同數值格式下的Attention 輸出與該值進行了比較(如圖4 所示)。
圖 4:FP64 下 Baseline Attention「黃金值」的比較。
結果表明,Flash Attention 的數值偏差大約是 BF16 下 Baseline 的 10 倍。
為了進一步分析這種觀察到的數值偏差,研究者保持 tile 大小和 SRAM 大小不變的同時,掃描了矩陣的序列長度(如圖 5 所示)。
圖 5: 序列長度對 Flash Attention 數值偏差的影響。
如圖所示,隨著序列長度的增加,無論是透過(a)最大差異上限的測量,或是透過(b)差異的平均值和標準差的測量,Flash Attention和Baseline Attention 之間的數值偏差都在增加。
除此之外,研究者也利用微基準設計進行不同最佳化的實驗,以便更了解數值偏差的影響(如圖 6 所示)。
圖 6a 顯示了調換 block 維數的順序如何導致 Flash Attention 和 Baseline Attention 之間的數值差異增大。圖 6b 中的其他擾動,例如限制 tile 大小為正方形,不會對數值偏差產生影響。圖 6c 顯示了 block/tile 大小越大,數值偏差越小。
圖 6: 演算法的改變及其對觀察到的數值偏差的影響。
透過權重差異來了解數值偏差
雖然在前向傳遞過程中,Flash Attention 可能會導致Attention 輸出的數值偏差,但這項研究的最終目標是確定這是否會在模型訓練過程中產生任何影響,以研究它是否會導致訓練的不穩定性。
因此,研究者希望量化 Flash Attention 是否在訓練過程中改變了模型,即上文觀察到的 Attention 輸出差異是否反映在訓練過程中更新的模型權重中。
研究者利用兩個指標來測量使用 Baseline Attention 訓練的模型與使用 Flash Attention 訓練的模型之間的模型權重差異。首先計算最大差異,即找出權重矩陣之間差異的絕對值並取最大值,從而得出偏差的上限,如下所示:
雖然最大差值提供了數值偏差的上限,但它沒有考慮到每個矩陣的分佈。因此,研究者透過 Wasserstein Distance 來量化權重差異,這是衡量張量之間相似性的常用度量。雖然在計算上稍微複雜,但 Wasserstein Distance 包含了張量分佈的形狀資訊以衡量相似性。計算公式概述如下:
數值越低,表示矩陣之間的相似度越高。
利用這兩個指標,研究者隨後量化了在整個訓練過程中與Baseline Attention 相比,Flash Attention 的模型權重是如何變化的:
#根據Wasserstein Distance 和Max Difference 這兩個指標,在整個訓練過程中,Flash Attention 的加入確實改變了模型權重,而且隨著訓練的繼續,這種差異只會越來越大,這顯示了使用Flash Attention 訓練的模型與使用Baseline Attention 訓練的相同模型收斂到了不同的模型。
然而,訓練是一個隨機過程,某些模型結構的改變可能會在下游效應和準確性方面產生相似的結果。即使使用 Flash Attention 和 Baseline Attention 訓練的模型權重不同,這也是值得關注的。
完全訓練模型並評估準確性是一項昂貴且資源密集的任務,特別是對於訓練需要數月的大模型來說。
研究者透過配置一個 proxy 來探索:
(a) 這些權重變化的意義有多大?
(b) 能否將其與其他廣泛採用的訓練優化中的標準權重變化聯繫起來?
為了實現這一目標,研究者設計了一系列實驗來比較在不同場景下,訓練過程中的權重差異是如何變化的。
除了對比使用 Flash Attention 和 Baseline Attention 的訓練過程外,他們還量化了在訓練開始時權重被初始化為不同隨機值的相同訓練過程中的權重差異。這提供了一個界限,因為隨機權重初始化是一種常用的技術,並且通常會產生等效的結果。
此外,研究者也測量了使用不同精度訓練的模型權重的變化。數值精確度(即 FP16 與 FP32)有可能導致下游變化,這作為確定了 Flash Attention 權重重要性的一個上限。
如圖8 所示,可以發現,使用Flash Attention 的模型權重偏差變化率與不同模型初始化的權重偏差變化率相當或更小(注意紅色和藍色曲線的斜率)。
此外,使用 FP16 與 FP32 時的權重變化率比不同模型初始化時的權重變化率更高,變化也更大。
這些結果提供了一個proxy,並表明:「雖然Flash Attention 會出現數值偏差,但它會被隨機模型初始化和低精度訓練所限制。而且所引入的模型權重偏差大約是低精度訓練時的1/2 至1/5 倍。相對權重差異。
更多研究細節,可參考原文。
以上是Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動的詳細內容。更多資訊請關注PHP中文網其他相關文章!

斯坦福大學以人為本人工智能研究所發布的《2025年人工智能指數報告》對正在進行的人工智能革命進行了很好的概述。讓我們用四個簡單的概念來解讀它:認知(了解正在發生的事情)、欣賞(看到好處)、接納(面對挑戰)和責任(弄清我們的責任)。 認知:人工智能無處不在,並且發展迅速 我們需要敏銳地意識到人工智能發展和傳播的速度有多快。人工智能係統正在不斷改進,在數學和復雜思維測試中取得了優異的成績,而就在一年前,它們還在這些測試中慘敗。想像一下,人工智能解決複雜的編碼問題或研究生水平的科學問題——自2023年

Meta的Llama 3.2:多模式和移動AI的飛躍 Meta最近公佈了Llama 3.2,這是AI的重大進步,具有強大的視覺功能和針對移動設備優化的輕量級文本模型。 以成功為基礎

本週的AI景觀:進步,道德考慮和監管辯論的旋風。 OpenAI,Google,Meta和Microsoft等主要參與者已經釋放了一系列更新,從開創性的新車型到LE的關鍵轉變

連接的舒適幻想:我們在與AI的關係中真的在蓬勃發展嗎? 這個問題挑戰了麻省理工學院媒體實驗室“用AI(AHA)”研討會的樂觀語氣。事件展示了加油

介紹 想像一下,您是科學家或工程師解決複雜問題 - 微分方程,優化挑戰或傅立葉分析。 Python的易用性和圖形功能很有吸引力,但是這些任務需要強大的工具

Meta's Llama 3.2:多式聯運AI強力 Meta的最新多模式模型Llama 3.2代表了AI的重大進步,具有增強的語言理解力,提高的準確性和出色的文本生成能力。 它的能力t

數據質量保證:與Dagster自動檢查和良好期望 保持高數據質量對於數據驅動的業務至關重要。 隨著數據量和源的增加,手動質量控制變得效率低下,容易出現錯誤。

大型機:AI革命的無名英雄 雖然服務器在通用應用程序上表現出色並處理多個客戶端,但大型機是專為關鍵任務任務而建立的。 這些功能強大的系統經常在Heavil中找到


熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

AI Hentai Generator
免費產生 AI 無盡。

熱門文章

熱工具

mPDF
mPDF是一個PHP庫,可以從UTF-8編碼的HTML產生PDF檔案。原作者Ian Back編寫mPDF以從他的網站上「即時」輸出PDF文件,並處理不同的語言。與原始腳本如HTML2FPDF相比,它的速度較慢,並且在使用Unicode字體時產生的檔案較大,但支援CSS樣式等,並進行了大量增強。支援幾乎所有語言,包括RTL(阿拉伯語和希伯來語)和CJK(中日韓)。支援嵌套的區塊級元素(如P、DIV),

DVWA
Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中

SecLists
SecLists是最終安全測試人員的伙伴。它是一個包含各種類型清單的集合,這些清單在安全評估過程中經常使用,而且都在一個地方。 SecLists透過方便地提供安全測試人員可能需要的所有列表,幫助提高安全測試的效率和生產力。清單類型包括使用者名稱、密碼、URL、模糊測試有效載荷、敏感資料模式、Web shell等等。測試人員只需將此儲存庫拉到新的測試機上,他就可以存取所需的每種類型的清單。

記事本++7.3.1
好用且免費的程式碼編輯器

MinGW - Minimalist GNU for Windows
這個專案正在遷移到osdn.net/projects/mingw的過程中,你可以繼續在那裡關注我們。 MinGW:GNU編譯器集合(GCC)的本機Windows移植版本,可自由分發的導入函式庫和用於建置本機Windows應用程式的頭檔;包括對MSVC執行時間的擴展,以支援C99功能。 MinGW的所有軟體都可以在64位元Windows平台上運作。