首頁  >  文章  >  科技週邊  >  英偉達玩轉剪枝、蒸餾:把Llama 3.1 8B參數減半,性能同尺寸更強

英偉達玩轉剪枝、蒸餾:把Llama 3.1 8B參數減半,性能同尺寸更強

PHPz
PHPz原創
2024-08-16 16:42:44247瀏覽
小模型崛起了。

上個月,Meta 發布了Llama 3.1 系列模型,其中包括Meta 迄今為止最大的405B 模型,以及兩個較小的模型,參數量分別為700億和80 億。

Llama 3.1 被認為是引領了開源新時代。然而,新一代的模型雖然效能強大,但部署時仍需要大量運算資源。

因此,業界出現了另一種趨勢,即開發小型語言模型 (SLM),這種模型在許多語言任務中表現足夠出色,部署起來也非常便宜。

最近,英偉達研究表明,結構化權重剪枝與知識蒸餾相結合,可以從初始較大的模型中逐步獲得較小的語言模型。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

                 .

經過剪枝和蒸餾,英偉達研究團隊將 Llama 3.1 8B 提煉為 Llama-3.1-Minitron 4B 開源了出來。這是英偉達在 Llama 3.1 開源系列中的第一個作品。

Llama-3.1-Minitron 4B 的表現優於類似大小的最先進的開源模型,包括 Minitron 4B、Phi-2 2.7B、Gemma2 2.6B 和 Qwen2-1.5B。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

這項研究的相關論文早在上個月就已經放出了。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强
  • 論文連結:https://www.arxiv.org/pdf/2407.14679

  • 論文

    論文標題:Compact Language Models via Pruning and Knowledge Distillation

剪枝和蒸餾

剪枝使模型變得更小、更簡,可以透過刪除層(深度剪枝)或刪除神經元和注意力頭以及嵌入通道(寬度剪枝)來實現。剪枝通常伴隨著一定程度的再訓練,以恢復準確率。

模型蒸餾是一種將知識從大型複雜模型(通常稱為教師模型)遷移到較小、較簡單的學生模型的技術。目標是創建一個更有效率的模型,該模型保留了原始較大模型的大部分預測能力,同時運行速度更快且資源消耗更少。

蒸餾方式主要包括兩種:SDG 微調與經典知識蒸餾,兩種蒸餾方式互補。本文主要關注經典知識蒸餾方法。

英偉達採用將剪枝與經典知識蒸餾相結合的方式來構造大模型,下圖展示了單一模型的剪枝和蒸餾過程(上)以及模型剪枝和蒸餾的鏈條(下) 。具體過程如下:

1. 英偉達從15B 模型開始,評估每個組件(層、神經元、頭和嵌入通道)的重要性,然後對模型進行排序和剪枝,使其達到目標大小:8B 模型。

2. 接著使用模型蒸餾進行了輕度再訓練,原始模型作為老師,剪枝後的模型作為學生。

3. 訓練結束後,以小模型(8B)為起點,剪枝和蒸餾為較小的 4B 模型。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

                       

需要注意的點是,在對模型剪枝之前,需要先了解模型的哪部分是重要的。英偉達提出了一種基於激活的純重要性評估策略,該策略可以同時計算所有相關維度(深度、神經元、頭和嵌入通道)的信息,使用一個包含1024 個樣本的小型校準數據集,並且只需要前向傳播。這種方法相比依賴梯度資訊並需要反向傳播的策略更加簡單且具有成本效益。 

在剪枝過程中,你可以針對給定軸或軸組合在剪枝和重要性估計之間進行迭代交替。實證研究顯示,使用單次重要性估計就足夠了,迭代估計不會帶來額外的好處。

利用經典知識蒸餾進行再訓練

下圖 2 展示了蒸餾過程,其中 N 層學生模型(剪枝後的模型)是從 M 層教師模型中(原始未剪枝模型)蒸餾而來。學生模型透過最小化嵌入輸出損失、logit 損失以及映射到學生區塊 S 和教師區塊 T 的 Transformer 編碼器特定損失組合來學習。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

                        

剪枝和蒸餾最佳實踐

英偉達基於緊湊語言模型中剪枝和知識蒸餾的廣泛消融研究,將自己的學習成果總結為以下幾種結構化壓縮最佳實務。

一是調整大小。

  • 要訓練一組 LLM,先訓練最大的一個,然後迭代地剪枝和蒸餾以獲得較小的 LLM。

  • 如果使用多階段訓練策略來訓練最大的模型,最好剪枝並對訓練最後階段獲得的模型進行重新訓練。

  • 對最接近目標大小的可用來源模型進行剪枝。

二是剪枝。

  • 優先考慮寬度剪枝而不是深度剪枝,這對於 15B 參數規模以下的模型效果很好。

  • 使用單一樣本(single-shot)重要性估計,因為迭代重要性估計沒有任何好處。

三是重新訓練。

  • 僅使用蒸餾損失進行重新訓練,而不是常規訓練。

  • 當深度明顯減少時,使用 logit、中間狀態和嵌入蒸餾。

  • 當深度沒有明顯減少時,使用 logit-only 蒸餾。

Llama-3.1-Minitron:將最佳實踐付諸應用

Meta 最近推出了強大的Llama 3.1 開源模型系列,在許多基準測試中可與閉源模型相媲美。 Llama 3.1 的參數範圍從巨大的 405B 到 70B、8B。

憑藉Nemotron 蒸餾的經驗,英偉達著手將Llama 3.1 8B 模型蒸餾為更小、更高效的4B 模型,採取以下措施:

  • 教師微調

  • 教師微調

    教師微調
  • 教師微調
  • Depth-only 剪枝
  • Width-only 剪枝
  • 準確率基準

準確率基準

效能基準

教師微調

為了修正模型訓練所基於的原始資料集的分佈偏差,英偉達首先在他們的資料集上(94B token)對未剪枝的8B 模型進行了微調。實驗表明,如果不糾正分佈偏差,教師模型在蒸餾時會為數據集提供次優指導。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强Depth-only 剪枝

為了從 8B 降到 4B,英偉達剪枝了 16 層(50%)。他們首先透過從模型中刪除每個層或連續子層組來評估它們的重要性,並觀察下游任務中 LM 損失的增加或準確率的降低。 下圖 5 顯示了刪除 1、2、8 或 16 層後驗證集上的 LM 損失值。例如,第 16 層的紅色圖表示如果刪除前 16 層,則出現 LM 損失。第 17 層表示如果保留第一層並刪除第 2 至第 17 層,也會出現 LM 損失。英偉達觀察到:開始和結束的層是最重要的。

                                  英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

然而,英偉達觀察到,這種 LM 損失不一定與下游表現直接相關。 下圖6 顯示了每個剪枝模型的Winogrande 準確率,它顯示最好刪除第16 到第31 層,其中第31 層是倒數第二層,剪枝模型的5-shot準確率明顯高於隨機準確率(0.5)。英偉達採納了這項見解,刪除了第 16 到第 31 層。

                        

  • Width-only 剪枝

  • 英偉達沿寬度軸剪枝了嵌入(隱藏)和MLP 中間維,以壓縮Llama 3.1 8B 。具體來說,他們使用前面描述的基於激活的策略來計算每個注意頭、嵌入通道和 MLP 隱藏維度的重要性分數。
  • 在重要性估計之後,英偉達選擇

  • 將 MLP 中間維從 14336 剪枝到 9216。

    將隱藏大小從 4096 剪枝到 3072。 重新訓練注意頭數量和層數。

值得一提的是,在單樣本剪枝之後,寬度剪枝的 LM 損失高於深度剪枝。然而,經過短暫的重新訓練後,趨勢發生了逆轉。

準確率基準

英偉達使用以下參數對模型進行蒸餾

  • 峰值學習率= 1e-4

  • 最小學習率= 1e-5

  • 40 步線性預熱

  • 餘弦衰減計畫

  • 全域批次大小= 1152

下表1 顯示了Llama-3.1-Minitron 4B 模型變體(寬度剪枝和深度剪枝)與原始Llama 3.1 8B 模型、其他類似大小的模型在跨多個領域的基準測試中的性能比較。整體而言,英偉達再次證實了寬度剪枝策略相較於遵循最佳實踐的深度剪枝的有效性。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

                        

為了驗證蒸餾後的模型是否可以成為強大的指令模型,英偉達使用 NeMo-Aligner 對 Llama-3.1-Minitron 4B 模型進行了微調。

他們使用了Nemotron-4 340B 的訓練數據,在IFEval、MT-Bench、ChatRAG-Bench 和Berkeley Function Calling Leaderboard (BFCL) 上進行了評估,以測試指令遵循、角色扮演、RAG 和函數呼叫功能。最後確認 Llama-3.1-Minitron 4B 模型可以成為可靠的指令模型,其表現優於其他基準 SLM。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

                       base

效能基準

英偉達利用NVIDIA TensorRT-LLM(一種用於最佳化LLM 推理的開源工具包)優化了Llama 3.1 8B和Llama-3.1-Minitron 4B 模型。

下兩張圖顯示了不同模型在不同用例下以FP8 和FP16 精度每秒的吞吐量請求,表示為8B 模型的batch size 為32 的輸入序列長度/ 輸出序列長度(ISL/ OSL) 組合以及4B 模型的batch size 為64 的輸入序列長度/ 輸出序列長度(ISL/OSL) 組合,這要歸功於在一塊英偉達H100 80GB GPU 上,較小的權重允許較大的batch size。

Llama-3.1-Minitron-4B-Depth-Base 變體是最快的,平均吞吐量約為Llama 3.1 8B 的2.7 倍,而Llama-3.1-Minitron-4B-Width-Base 變體的平均吞吐量約為Llama 3.1 8B 的1.8 倍。與 BF16 相比,在 FP8 中部署還可使這三種型號的效能提高約 1.3 倍。

英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强
英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

            32,Llama-3.1-Minitron 4B 型號為BS=64。1x H100 80GB GPU。

結論

剪枝和經典知識提煉是一種非常經濟高效的方法,可以逐步獲得更小尺寸的LLM,與在所有領域從頭開始訓練相比,可實現更高的準確性。與合成資料式微調或從頭開始預訓練相比,這是一種更有效且資料效率更高的方法。

Llama-3.1-Minitron 4B 是英偉達首次嘗試使用最先進的開源 Llama 3.1 系列完成的探索。要在 NVIDIA NeMo 中使用 Llama-3.1 的 SDG 微調,可參閱 GitHub 上的 /sdg-law-title-generation 部分。

有關更多信息,請參閱以下資源:

  • https://arxiv.org/abs/2407.14679

  • https://github.com/NVlabs/Minitron
  • https://huggingface.co/nvidia/Llama-3.1-Minitron-4B-Width-Base
  • https://huggingface.co/nvidia/Llama-3.1-Minitron-4B-Depth-Base

參考鏈接:

https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3 -1-minitron-4b-model/

以上是英偉達玩轉剪枝、蒸餾:把Llama 3.1 8B參數減半,性能同尺寸更強的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn