首頁 >科技週邊 >人工智慧 >Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動

Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動

WBOY
WBOY原創
2024-05-30 13:24:53777瀏覽

Meta FAIR 聯合哈佛優化大規模機器學習時產生的資料偏差,提供了新的研究架構。

據所周知,大語言模型的訓練常常需要數月的時間,使用數百甚至上千個GPU。以LLaMA2 70B模型為例,其訓練總共需要1,720,320個GPU小時。由於這些工作負載的規模和複雜性,導致訓練大模型存在著獨特的系統性挑戰。

最近,許多機構在訓練SOTA生成式AI模型時報告了訓練過程中的不穩定情況,它們通常以損失尖峰的形式出現,例如Google的PaLM模型訓練過程中出現了多達20次的損失尖峰。

數值偏差是造成這種訓練不準確性的根因,由於大語言模型訓練執行成本極高,如何量化數值偏差儼然成為關鍵問題。

在最新的一項工作中,來自 Meta、哈佛大學的研究者開發了一個原則性定量方法來理解訓練優化中的數值偏差。以此評估不同的最新最佳化技術,並確定它們在用於訓練大模型時是否可能引入意外的不穩定性。 研究者發現,儘管現有的最佳化方法在一些任務上表現出色,但在大型模型上應用時,會出現一些數值偏差。這種數值偏差可能會在訓練過程中產生不穩定性,導致模型的表現下降。 為了解決這個問題,研究者提出了一種基於原則性量化方法的最佳化

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动


  • 論文標題:Is Flash Attention Stable?
  • 論文連結:https://arxiv.org/pdf/2405.02803

結果發現,在單獨的前向傳遞過程中,Flash Attention 的數值偏差比BF16 的Baseline Attention 大一個數量級。

具體而言,該方法包括兩個階段,包括:

  • #開發一個微基準來擾動給定最佳化中的數值精度;
  • 透過基於Wasserstein 距離的資料驅動分析評估數值偏差如何轉換為模型權重的變化。

研究者分析了 SOTA 最佳化技術 Flash Attention,並量化了可可能引入的數值偏差。 Flash Attention 是一種廣泛用於加速注意力機制的技術,通常被認為是 Transformer 模型中的系統瓶頸。 Flash Attention 在提高速度和減少記憶體存取量的同時,也依賴演算法最佳化,而演算法最佳化有可能導致數值偏差的增加。

研究者假設添加重新縮放因子(rescaling factors )可能會引入無意的近似,導致數值折衷,這可能會在後續影響訓練穩定性。

他們在多模態文字到影像工作負載的背景下分析了 Flash Attention,以確定 Flash Attention 與其基準之間數值偏差的潛在重要性。最終,他們引入了一個框架來量化訓練優化的數值偏差及其下游影響。

研究者在數值偏差量化上主要做出了以下兩點貢獻:

(1)設計了一個微基準來分離數值精度對數值偏差的影響。

研究者所設計的微基準作為一種技術,用於衡量和量化傳統黑盒最佳化(如 Flash Attention)所導致的數值偏差。透過擾動通常在提供的內核中不可用的方面,他們開創性地發現在低數值精度(BF16)下,與 Baseline Attention 相比,Flash Attention 的數值偏差大約高出一個數量級。

(2)基於 Wasserstein Distance 度量進行了資料驅動的分析。

透過此分析,研究者將觀察到的數值偏差置於上下文,並為其對下游模型屬性的影響形成一個上限(upper bound)。在研究者的案例研究中,他們能夠限制觀察到的數值偏差的影響,並發現:「Flash Attention 引入的模型權重偏差大約為低精度訓練的1/2 至1/5 倍。」

這項研究強調了開發一種原則性方法的重要性:「不僅要量化,而且要將訓練優化對數值偏差的影響置於上下文中。」透過建立代理(proxies)來將數值偏差置於上下文中,旨在推斷通常難以衡量的下游模型效果(即訓練不穩定性)的可能性。

實驗方法

研究者首先發展了一個微基準來分離並研究 Flash Attention 所造成的數值偏差。如圖 2 所示,他們透過對 Flash Attention 進行數值上的重新實現,以分析不同的數值精度,並在演算法的每個步驟應用潛在的最佳化措施。

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

圖 2: 微基準設計摘要。

這是必要的,因為 Flash Attention 核心目前僅支援 FP16 和 BF16 數值格式。該核心也是 CUDA 程式碼的包裝 API 調用,這使得擾動演算法以檢查數值偏差的影響變得具有挑戰性。

相比之下,他們的微基準設計允許在演算法內部進行精度輸入和修改。研究者將微基準與原始的 Flash Attention kernel 進行了驗證。

他們進一步設計了一種技術,以比較模型執行過程中每個步驟的 Attention 矩陣的輸出。並修改了模型程式碼,每次呼叫注意力時都計算 Baseline Attention 和 Flash Attention,這允許對相同的輸入矩陣進行精確的輸出矩陣比較。

為了將其置於上下文中,研究者也透過相同和獨立的訓練運行,使用 Max difference 和 Wasserstein Distance 度量來量化模型權重在整個訓練過程中的差異。

對於訓練實驗,研究者則使用一種將文字輸入轉換為圖像的生成式 AI workload(即文字到圖像模型)。他們使用 Shutterstock 資料集重新訓練模型,並在一組英偉達 80GB A100 GPU 叢集上執行此實驗。

透過微基準量化數值偏差

#研究者首先分析了 Flash Attention 在前向傳遞過程中的影響。他們利用微基準測試,在隨機初始化查詢、鍵、值向量相同的情況下,檢驗不同數值精確度對 Attention 計算的輸出矩陣的影響。

如圖3 所示,當研究者使用從BF16 到FP64 變化的不同數值格式時,Flash Attention 和Baseline Attention 之間的數值偏差隨著尾數位數的增加而減小。這表明數值差異是由於較少的尾數位數所固有的近似造成的。

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

圖 3:數值格式對於 Flash Attention 的數值偏差所產生的效果。

之後,研究者為進行標準比較,在FP64 數值格式下的Baseline Attention 設定了「黃金值」,然後將不同數值格式下的Attention 輸出與該值進行了比較(如圖4 所示)。

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

圖 4:FP64 下 Baseline Attention「黃金值」的比較。

結果表明,Flash Attention 的數值偏差大約是 BF16 下 Baseline 的 10 倍。

為了進一步分析這種觀察到的數值偏差,研究者保持 tile 大小和 SRAM 大小不變的同時,掃描了矩陣的序列長度(如圖 5 所示)。

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

圖 5: 序列長度對 Flash Attention 數值偏差的影響。

如圖所示,隨著序列長度的增加,無論是透過(a)最大差異上限的測量,或是透過(b)差異的平均值和標準差的測量,Flash Attention和Baseline Attention 之間的數值偏差都在增加。

除此之外,研究者也利用微基準設計進行不同最佳化的實驗,以便更了解數值偏差的影響(如圖 6 所示)。

圖 6a 顯示了調換 block 維數的順序如何導致 Flash Attention 和 Baseline Attention 之間的數值差異增大。圖 6b 中的其他擾動,例如限制 tile 大小為正方形,不會對數值偏差產生影響。圖 6c 顯示了 block/tile 大小越大,數值偏差越小。

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

圖 6: 演算法的改變及其對觀察到的數值偏差的影響。

透過權重差異來了解數值偏差

雖然在前向傳遞過程中,Flash Attention 可能會導致Attention 輸出的數值偏差,但這項研究的最終目標是確定這是否會在模型訓練過程中產生任何影響,以研究它是否會導致訓練的不穩定性。

因此,研究者希望量化 Flash Attention 是否在訓練過程中改變了模型,即上文觀察到的 Attention 輸出差異是否反映在訓練過程中更新的模型權重中。

研究者利用兩個指標來測量使用 Baseline Attention 訓練的模型與使用 Flash Attention 訓練的模型之間的模型權重差異。首先計算最大差異,即找出權重矩陣之間差異的絕對值並取最大值,從而得出偏差的上限,如下所示:

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

雖然最大差值提供了數值偏差的上限,但它沒有考慮到每個矩陣的分佈。因此,研究者透過 Wasserstein Distance 來量化權重差異,這是衡量張量之間相似性的常用度量。雖然在計算上稍微複雜,但 Wasserstein Distance 包含了張量分佈的形狀資訊以衡量相似性。計算公式概述如下:

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

數值越低,表示矩陣之間的相似度越高。

利用這兩個指標,研究者隨後量化了在整個訓練過程中與Baseline Attention 相比,Flash Attention 的模型權重是如何變化的:

Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

#根據Wasserstein Distance 和Max Difference 這兩個指標,在整個訓練過程中,Flash Attention 的加入確實改變了模型權重,而且隨著訓練的繼續,這種差異只會越來越大,這顯示了使用Flash Attention 訓練的模型與使用Baseline Attention 訓練的相同模型收斂到了不同的模型。

然而,訓練是一個隨機過程,某些模型結構的改變可能會在下游效應和準確性方面產生相似的結果。即使使用 Flash Attention 和 Baseline Attention 訓練的模型權重不同,這也是值得關注的。

完全訓練模型並評估準確性是一項昂貴且資源密集的任務,特別是對於訓練需要數月的大模型來說。

研究者透過配置一個 proxy 來探索:

(a) 這些權重變化的意義有多大?

(b) 能否將其與其他廣泛採用的訓練優化中的標準權重變化聯繫起來?

為了實現這一目標,研究者設計了一系列實驗來比較在不同場景下,訓練過程中的權重差異是如何變化的。

除了對比使用 Flash Attention 和 Baseline Attention 的訓練過程外,他們還量化了在訓練開始時權重被初始化為不同隨機值的相同訓練過程中的權重差異。這提供了一個界限,因為隨機權重初始化是一種常用的技術,並且通常會產生等效的結果。

此外,研究者也測量了使用不同精度訓練的模型權重的變化。數值精確度(即 FP16 與 FP32)有可能導致下游變化,這作為確定了 Flash Attention 權重重要性的一個上限。

如圖8 所示,可以發現,使用Flash Attention 的模型權重偏差變化率與不同模型初始化的權重偏差變化率相當或更小(注意紅色和藍色曲線的斜率)。

此外,使用 FP16 與 FP32 時的權重變化率比不同模型初始化時的權重變化率更高,變化也更大。

這些結果提供了一個proxy,並表明:「雖然Flash Attention 會出現數值偏差,但它會被隨機模型初始化和低精度訓練所限制。而且所引入的模型權重偏差大約是低精度訓練時的1/2 至1/5 倍。相對權重差異。

更多研究細節,可參考原文。 Flash Attention稳定吗?Meta、哈佛发现其模型权重偏差呈现数量级波动

以上是Flash Attention穩定嗎? Meta、哈佛發現其模型權重偏差呈現數量級波動的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn