首頁 >科技週邊 >人工智慧 >650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

王林
王林轉載
2023-06-20 15:57:581491瀏覽

在大模型方向上,科技巨頭在訓更大的模型,學界則在想辦法搞最佳化。最近,優化算力的方法又上升到了新的高度。

大型語言模型(LLM)徹底改變了自然語言處理(NLP)領域,展現了湧現、頓悟等非凡能力。然而,若想建構出具備一定通用能力的模型,就需要數十億參數,這大幅提高了 NLP 研究的門檻。在 LLM 模型調優過程中通常又需要昂貴的 GPU 資源,例如 8×80GB 的 GPU 設備,這使得小型實驗室和公司很難參與這一領域的研究。

最近,人們正在研究參數高效的微調技術(PEFT),例如 LoRA 和 Prefix-tuning,為利用有限資源對 LLM 進行調優提供了解決方案。然而,這些方法並沒有為全參數微調提供實用的解決方案,而全參數微調已被公認為比參數高效微調更強大的方法。

在上週復旦大學邱錫鵬團隊提交的論文《Full Parameter Fine-tuning for Large Language Models with Limited Resources》中,研究人員提出了一種新的優化器LOw- Memory Optimization(LOMO)。

透過將 LOMO 與現有的記憶體節省技術集成,與標準方法(DeepSpeed 解決方案)相比,新方法將記憶體使用量減少到了先前的 10.8%。因此,新方法能夠在一台具有 8×RTX 3090 的機器上對 65B 模型進行全參數微調,每個 RTX 3090 具有 24GB 記憶體。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

論文連結:https://arxiv.org/abs/2306.09782

#在該工作中,作者分析了LLM 中記憶體使用的四個面向:啟動、最佳化器狀態、梯度張量和參數,並對訓練過程進行了三方面的最佳化:

  1. 從演算法的角度重新思考了優化器的功能,發現SGD 在微調LLM 完整參數方面是一種很好的替代品。這使得作者可以刪除優化器狀態的整個部分,因為 SGD 不會儲存任何中間狀態。
  2. 新提出的優化器 LOMO 將梯度張量的記憶體使用量減少到 O (1),相當於最大梯度張量的記憶體使用量。
  3. 為了使用 LOMO 穩定混合精度訓練,作者整合了梯度歸一化、損失縮放,並在訓練期間將某些計算轉換為全精度。

新技術讓記憶體的使用等於參數使用加上啟動和最大梯度張量。全參數微調的記憶體使用被推向了極致,其僅等同於推理的使用。這是因為 forward backward 過程的記憶體佔用應該不會比單獨的 forward 過程少。值得注意的是,在使用 LOMO 節省記憶體時,新方法確保了微調過程不受影響,因為參數更新過程仍然等同於 SGD。

該研究評估了 LOMO 的記憶體和吞吐量效能,顯示借助 LOMO,研究者在 8 個 RTX 3090 GPU 上就可以訓練 65B 參數的模型。此外,為了驗證 LOMO 在下游任務上的效能,他們應用 LOMO 來調優 SuperGLUE 資料集集合上 LLM 的全部參數。結果顯示了 LOMO 對具有數十億參數的 LLM 進行最佳化的有效性。

方法介紹

在方法部分,本文詳細介紹了 LOMO(LOW-MEMORY OPTIMIZATION)。一般而言,梯度張量表示一個參數張量的梯度,其大小與參數相同,這樣一來記憶體開銷較大。而現有的深度學習框架如 PyTorch 會為所有參數儲存梯度張量。現階段,儲存梯度張量有兩方面原因:計算最佳化器狀態以及歸一化梯度。

由於研究採用 SGD 作為最佳化器,因此沒有依賴梯度的最佳化器狀態,並且他們有一些梯度歸一化的替代方案。

他們提出了 LOMO,如演算法 1 所示,LOMO 將梯度計算與參數更新融合在一個步驟中,從而避免了梯度張量的儲存。

下圖為 SGD 和 LOMO 在反向傳播和參數更新階段的比較。 Pi 為模型參數,Gi 為 Pi 對應的梯度。 LOMO 將梯度計算和參數更新整合到一個步驟中,使梯度張量最小。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

LOMO 對應的演算法偽代碼:

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

具體而言,該研究將vanilla 梯度下降表示為

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

,這是一個兩步驟過程,首先是計算梯度,然後更新參數。融合版本為 

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

 此研究的關鍵想法是在計算梯度時立即更新參數,這樣就不會在記憶體中儲存梯度張量。這一步可以透過在向反向傳播中註入 hook 函數來實現。 PyTorch 提供了注入 hook 函數的相關 API,但卻無法以目前的 API 實現精確的即時更新。相反,該研究在記憶體中最多儲存一個參數的梯度,並隨著反向傳播逐一更新每個參數。本文方法減少了梯度的記憶體使用,從儲存所有參數的梯度到只儲存一個參數的梯度。

大部分 LOMO 記憶體使用與參數高效微調方法的記憶體使用一致,這表明 LOMO 與這些方法結合只會導致梯度佔用記憶體的輕微增加。這樣就可以為 PEFT 方法調優更多的參數。

實驗結果

在實驗部分,研究者從三個方面評估了他們提出的方法,即記憶體使用情況、吞吐量和下游性能。如果不作進一步解釋,所有的實驗都是用 7B 到 65B 的 LLaMA 模型進行的。

記憶體使用情況

#研究者首先剖析了,在不同設定下,訓練期間的模型狀態和啟動的記憶體使用情況。如表1 所示,與AdamW 優化器相比,LOMO 優化器的使用導致記憶體佔用大幅減少,從102.20GB 減少到14.58GB;與SGD 相比,在訓練LLaMA-7B 模型時,記憶體佔用從51.99GB減少到14.58GB。記憶體用量的大幅減少主要歸因於梯度和優化器狀態的記憶體需求減少。因此,在訓練過程中,記憶體大部分被參數佔據,與推理過程中的記憶體用量相當。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

如圖2 所示,若採用AdamW 最佳化器進行LLaMA-7B 訓練,則相當大比例的記憶體( 73.7%)被指派給優化器狀態。用 SGD 優化器取代 AdamW 優化器可以有效減少優化器狀態佔用記憶體的百分比,從而減輕 GPU 記憶體使用(從 102.20GB 減少到 51.99GB)。如果使用 LOMO,參數更新和 backward 會被融合到一個步驟中,進一步消除優化器狀態對記憶體的需求。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

吞吐量

#研究者比較了LOMO、AdamW 和SGD的吞吐性能。實驗是在一台配備了 8 個 RTX 3090 GPU 的伺服器上進行的。

对于 7B 的模型,LOMO 的吞吐量呈现显著优势,超过 AdamW 和 SGD 约 11 倍。这一重大改进可归功于 LOMO 在单个 GPU 上训练 7B 模型的能力,这减少了 GPU 间的通信开销。与 AdamW 相比,SGD 的吞吐量略高,这可归因于 SGD 排除了动量和方差的计算。

至于 13B 模型,由于内存的限制,它无法在现有的 8 个 RTX 3090 GPU 上用 AdamW 训练。在这种情况下,模型的并行性对 LOMO 来说是必要的,LOMO 在吞吐量方面仍然优于 SGD。这一优势归功于 LOMO 的内存高效特性,以及只需要两个 GPU 以相同的设置来训练模型,从而降低了通信成本,提高了吞吐量。此外,在训练 30B 模型时,SGD 在 8 个 RTX 3090 GPU 上遇到了内存不足(OOM)的问题,而 LOMO 在只有 4 个 GPU 的情况下表现良好。

最后,研究者使用 8 个 RTX 3090 GPU 成功训练了 65B 模型,实现了 4.93 TGS 的吞吐量。利用这样的服务器配置和 LOMO,模型在 1000 个样本上的训练过程(每个样本包含 512 个 token)大约需要 3.6 小时。

下游性能

为了评估 LOMO 在微调大型语言模型方面的有效性,研究者进行了一系列广泛的实验。他们将 LOMO 与其他两种方法进行比较,一种是不需要微调的 Zero-shot,另一种是目前很流行的参数高效微调技术 LoRA。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

表 3 结果显示:
  • LOMO 的表现明显好于 Zero-shot;
  • 在大多数实验中,LOMO 普遍优于 LoRA;
  • LOMO 可以有效扩展至 650 亿参数的模型。

LOMO 和 LoRA 在本质上是相互独立的。为了验证这一说法,研究者使用 LLaMA-13B 在 BoolQ 和 MultiRC 数据集上进行了实验。结果如图 3 所示。

他们发现,LOMO 在持续增强 LoRA 的性能,不管 LoRA 取得的结果有多高。这表明,LOMO 和 LoRA 采用的不同微调方法是互补的。具体来说,LOMO 专注于微调预训练模型的权重,而 LoRA 则调整其他模块。因此,LOMO 不会影响到 LoRA 的性能;相反,它有助于对下游任务进行更好的模型调优。

650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了

更多细节参见原论文。

以上是650億參數,8塊GPU就能全參數微調:邱錫鵬團隊把大模型門檻打下來了的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除