首頁 >科技週邊 >人工智慧 >徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

WBOY
WBOY原創
2024-07-17 16:08:17608瀏覽

從 125M 到 1.3B 的大模型,效能都有提升。


難以置信,這件事終於發生了。

一種全新的大語言模型(LLM)架構有望取代至今在 AI 領域如日中天的 Transformer,效能也比 Mamba 更好。本週一,有關 Test-Time Training(TTT)的論文成為了人工智慧社群熱議的話題。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

論文連結:https://arxiv.org/abs/2407.04620

該研究的作者來自史丹佛大學、加州大學柏克萊分校、加州大學聖迭戈分校
研究的作者來自史丹佛大學、加州大學柏克萊分校、加州大學聖迭戈分校
。他們設計了一種新架構 TTT,用機器學習模型取代了 RNN 的隱藏狀態。該模型透過輸入 token 的實際梯度下降來壓縮上下文。
徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態該研究作者之一 Karan Dalal 表示,他相信這將根本性的改變語言模型方法。

在機器學習模型中,TTT 層直接取代Attention,並透過表達性記憶解鎖線性複雜性架構,使我們能夠在上下文中訓練具有數百萬(有時是數十億)個token 的LLM 。 

作者在 125M 到 1.3B 參數規模的大模型上進行了一系列對比發現,TTT-Linear 和 TTT-MLP 均能匹敵或擊敗最強大的 Transformers 和 Mamba 架構方法。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態TTT 層作為一種新的資訊壓縮和模型記憶機制,可以簡單地直接取代 Transformer 中的自註意力層。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態與Mamba 相比,TTT-Linear 的困惑度更低,FLOP 更少(左),對長上下文的利用更好(右):

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態這不僅在理論上是線性的複雜度,而且實際運作時間也更快。

  • 在論文上線後,作者公開了代碼與jax 以供人們訓練和測試:https://github.com/test-time-training/ttt-lm-jax
還有PyTorch 推理程式碼:https://github.com/test-time-training/ttt-lm-pytorch

方法介紹到本質上的固有挑戰的:與自註意力機制不同,RNN 層必須將上下文壓縮為固定大小的隱藏狀態,更新規則需要發現數千甚至數百萬個token 之間的底層結構和關係。
研究團隊首先觀察到自監督學習可以將大量訓練集壓縮為 LLM 等模型的權重,而 LLM 模型通常表現出對其訓練資料之間語義聯繫的深刻理解。
受此觀察的啟發,研究團隊設計了一類新的序列建模層,其中隱藏狀態是一個模型,更新規則是自監督學習的一個步驟。由於更新測試序列上的隱藏狀態的過程相當於在測試時訓練模型,因此研究團隊將這種新的層稱為測試時訓練(Test-Time Training,TTT)層。

研究團隊引入兩個簡單的實例:TTT-Linear 和 TTT-MLP,其中隱藏狀態分別是線性模型和兩層 MLP。 TTT 層可以整合到任何網路架構中並進行端對端優化,類似於 RNN 層和自註意力。 徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

為了讓TTT 層更加高效,該研究採取了一些技巧來改進TTT 層:

首先,類似於在常規訓練期間對小批量序列採取gradient step 以獲得更好的並行性,該研究在TTT 期間使用小批量token。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

其次,該研究為每個 TTT 小批量內的操作開發了一種雙重形式,以更好地利用現代 GPU 和 TPU。雙重形式的輸出與簡單實現等效,但訓練速度快了 5 倍以上。如圖 3 所示,TTT-Linear 在 8k 上下文中比 Transformer 更快,與 Mamba 相當。

研究團隊認為:所有序列建模層都可以看作將歷史上下文儲存到隱藏狀態,如圖 4 所示。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

例如,RNN 層(如 LSTM、RWKV 和 Mamba 層)將上下文壓縮為跨時間的固定大小狀態。這種壓縮會產生兩種後果:一方面,將輸入標記 x_t 對應到輸出 token z_t 是高效率的,因為每個 token 的更新規則和輸出規則都需要恆定的時間。另一方面,RNN 層在長上下文中的效能受限於其隱藏狀態 s_t 的表現力。

自註意力也可以從上述角度來看待,只不過它的隱藏狀態(通常稱為 Key-Value 快取)是一個隨 t 線性增長的列表。它的更新規則只是將目前的 KV 元組(tuple)追加到該清單中,而輸出規則則掃描 t 前的所有元組,以形成注意力矩陣。隱藏狀態明確地儲存了所有歷史上下文,無需壓縮,這使得自註意力在長上下文方面比 RNN 層更具表現力。然而,掃描這個線性成長的隱藏狀態所需的時間也是線性成長的。為了保持長上下文的效率和表現力,研究者需要一種更好的壓縮啟發式。具體來說,需要將成千上萬或可能數百萬的 token 壓縮到一個隱藏狀態中,以有效捕捉它們的底層結構和關係。這聽起來似乎有些高難度,但其實很多人對這種啟發式非常熟悉。

骨幹架構。將任何 RNN 層整合到更大架構中的最簡潔方法是直接取代 Transformer 中的自註意力,這裡稱為骨幹。然而,現有的 RNN(如 Mamba 和 Griffin 等)都使用了與 Transformer 不同的骨幹層。最值得注意的是,它們的骨幹層在 RNN 層之前包含了時間卷積,這可能有助於收集跨時間的局部資訊。在對 Mamba 主幹網進行試驗後,研究者發現它也能改善 TTT 層的困惑度,因此將其納入了建議方法中,詳見圖 16。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

實驗結果

在實驗中,研究者將 TTT-Linear 、 TTT-MLP 與 Transformer、Mamba 這兩種基線進行了比較。

短文本

從圖11 可以得出以下結論:

(Mambao大多重疊。在 FLOP 預算較大的情況下,TTT-MLP (M) 的效能稍差。儘管 TTT-MLP 在各種模型大小下都比 TTT-Linear 有更好的困惑度,但 FLOPs 的額外成本抵消了這一優勢。
  • 8k 上下文,TTT-Linear (M) 和 TTT-MLP (M) 的表現都明顯優於 Mamba,這與 2k 上下文中的觀察結果截然不同。即使是使用 Transformer 主幹網路的 TTT-MLP (T) 在 1.3B 左右也比 Mamba 略勝一籌。一個顯著現像是,隨著上下文長度的增加,TTT 層相對於 Mamba 層的優勢也在擴大。
  • 上下文長度達到 8k,Transformer 在每種模型尺寸下的困惑度依舊表現不錯,但由於 FLOPs 成本的原因,已不具競爭力。
上圖結果展示了將 TTT 層從 Mamba 主幹網路切換到 Transformer 主幹網路的影響。研究者假設,當序列建模層的隱藏狀態表現力較低時,Mamba 主幹網路中的時序卷積會更有幫助。線性模型的表現力低於 MLP,因此從卷積中獲益更多。

長文本:書籍

為了評估長上下文的能力,研究者使用 Pile 的一個流行子集3,以 2 倍的增量對 1k 到上下文 32k Books3,以 2 倍的增量對 1k 到上下文長 32 的上下文進行實驗。這裡的訓練方法與 Pile 相同,而 TTT 層的所有實驗都在一次訓練運行中完成。從圖12 的結果子集,他們得出了以下觀察結果:

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

在Books 的2k 上下文中,Pile 2k 的所有觀察結果仍然成立,只是Mamba 現在的表現略好於TTT-Linear(而它們的線條在Pile 2k 中大致重疊)。

在 32k 上下文中,TTT-Linear (M) 和 TTT-MLP (M) 的表現都優於 Mamba,類似於 Pile 8k 的觀察結果。即使是採用 Transformer 主幹的 TTT-MLP (T) 在 32k 上下文中的表現也略優於 Mamba。

TTT-MLP (T) 在 1.3B 尺度下僅略差於 TTT-MLP (M)。如上所述,由於缺乏清晰的線性擬合,很難得出經驗縮放定律。然而,TTT-MLP (T) 的強勁趨勢表明,Transformer 主幹可能更適合更大的模型和更長的上下文,超出了我們的評估範圍。

時鐘時間

LLM 的訓練和推理可分解為前向、後向和生成。推理過程中的提示詞處理(也稱為預填充)與訓練過程中的前向運算相同,只是後向操作不需要儲存中間激活值。

由於前向(訓練和推理過程中)和後向都可以並行處理,因此這裡使用了雙重形式。產生新 token(也稱為解碼)本質上是順序性的,因此這裡使用了原始形式。

研究者提到,由於資源限制,本文實驗使用 JAX 編寫,並在 TPU 上運行。在 v5e-256 TPU pod 上,Transformer 基線在上下文為 2k 的情況下每次迭代訓練需要 0.30 秒,而 TTT-Linear 每次迭代需要 0.27 秒,在沒有任何系統優化的情況下快了 10%。鑑於 Mamba(以 PyTorch、Triton 和 CUDA 實現)只能在 GPU 上運行,為了進行公平比較,研究者將本文方法進行初步系統優化,使其能在 GPU 上運行。

圖 15 左側顯示了各個模型的前向內核在批次大小為 16 時的延遲。所有模型都是 1.3B(Mamba 為 1.4B)。值得注意的是,這裡的 Transformer 基線比 Mamba 論文中的快得多,因為此處使用了 vLLM ,而不是 HuggingFace Transformer 。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

此外,研究者還編寫了另一個用於生成的 GPU 內核,並在圖 15 右側以批次大小 512 為基準測試其速度。另一個常用的掛鐘時間(wall-clock time)指標是吞吐量(throughput),它考慮了使用更大的批次大小的潛在好處。對於吞吐量,上述所有觀察結果和方法之間的排序仍然有效。

主要作者

在 TTT 研究提交後,論文作者之一,UCSD 助理教授 Xiaolong Wang 發推表示祝賀。他表示,TTT 的研究持續了一年半,但測試時間訓練(TTT)這個想法從誕生到現在其實已經過了五年。雖然當初的想法和現在的成果完全不同了。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

TTT 論文的三位主要作者分別來自於史丹佛、UC Berkeley 和 UCSD。

其中 Yu Sun 是史丹佛大學的博士後,博士畢業於 UC Berkeley EECS,長期以來一直的研究方向就是 TTT。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

Xinhao Li 是 UCSD 在讀博士,他本科畢業於電子科技大學。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

Karan Dalal 是 UC Berkeley 在讀博士,他曾在高中時與他人共同創辦了一家名為 Otto 的獸醫遠距醫療新創公司。

徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態

上述三人,都把 test-time training 寫在了個人網站介紹研究方向的第一行。

更多研究細節,可參考原論文。

以上是徹底改變語言模型:全新架構TTT超越Transformer,ML模型取代RNN隱藏狀態的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn