搜尋
首頁科技週邊人工智慧如何在GPU資源受限情況下微調超大模型

如何在GPU資源受限情況下微調超大模型

問題:模型大小超過GPU 容量怎麼辦? 

本文的靈感來自Yandex資料分析學院所教授的「高效能深度學習系統」課程。

預備知識:#假設讀者已經了解神經網路的前傳遞和後向傳遞的工作原理,這對理解本文內容至關重要。文中使用PyTorch作為框架。

開始吧!

當試圖使用大型模型(即aka gpt-2-xl),它帶有5億多個參數,而你的GPU 資源受限,無法將它安裝到GPU上運行,或在模型訓練期間無法實現論文中定義的批次大小,此時該怎麼辦?也許可以選擇放棄,使用一個更輕量級版本的模型,或者減少訓練的批次大小,這樣的話,便無法獲得論文中描述的訓練結果。

但是,有一些技巧可以幫助解決上述問題。

下面來討論一些方法,即如何利用這些方法來微調具有15億個參數的GPT-2-XL模型。

問題的核心

#首先,來了解一下將模型載入到GPU中所需GPU記憶體問題的實質。

假設模型有 個FP32(32位元浮點)參數,需要在GPU上訓練這個模型,例如,執行Adam優化器。

透過計算,結果令人震驚。

 如何在GPU資源受限情況下微調超大模型

#假設已有12 GB記憶體的NVIDIA GeForce RTX 3060。首先, 1e9個FP32參數約佔4 GB的GPU記憶體。

同樣,對於梯度,也將保留相同數量的記憶體。所以,總共已經保留了8 GB的內存,由於還沒有開始訓練,也沒有載入優化器,載入優化器也同樣需要一定數量的內存。 Adam優化器需要為每個參數儲存第一備份和第二個備份,即需要8 GB額外記憶體。算下來,必須有大約16 GB的GPU內存,才能正確地將模型加載到GPU上,在本文的例子中,GPU只有12 GB的空閒內存。看起來很不妙,對吧?

然而,可以透過一些方法來嘗試解決這個問題,以下是相關內容:

  • 梯度累積/微批量;
  • 梯度檢查點;
  • #模型並行訓練;
  • 管道作業;
  • 張量並行化
混合精確度訓練;

記憶體卸載;

如何在GPU資源受限情況下微調超大模型#優化器8位元量化。

##################接下來,將詳細解讀這些技巧。 ###############開始################

問題:模型比GPU容量大,怎麼辦?

  • 簡單模式:無法適配批次大小為1
  • 專業模式:參數也沒辦法適應

#概述

如何在GPU資源受限情況下微調超大模型

#如果模型大於GPU容量,即便將批次大小設為1都不夠,那該怎麼辦呢?有一個解決方案,就是設定梯度檢查點,下面來看看這個概念。對於一個簡單的包含n層的前饋神經網路來說,梯度的計算圖如下:

 

神經網路層的活化對應於用f標記的節點,在正向傳遞期間,按順序對所有這些節點進行計算。對應於這些層的活化和參數的損失梯度以b標記的節點表示。在反向傳遞期間,所有這些節點都以相反的順序進行計算。 f個節點的計算結果用於計算b個節點,因此所有f個節點在向前傳遞後都保存在記憶體中。只有當反向傳播進展到足以計算出f節點的所有依賴關係時,它才能從記憶體中擦除。這意味著:簡單的反向傳播所需的記憶體隨神經網路層數n的變化呈線性增長。 如何在GPU資源受限情況下微調超大模型

下面是這些節點的計算順序,紫色陰影圓圈表示在給定時間需要將哪個節點儲存到記憶體之中。  

#梯度檢查點

如何在GPU資源受限情況下微調超大模型

如上所述的簡單反向傳播在運算方面是最優的:它只計算每個節點一次。但是,如果重新計算節點,可能會節省大量記憶體。例如,可以簡單地重新計算每個節點。執行的順序與所使用的記憶體如下圖所示:

 

如何在GPU資源受限情況下微調超大模型這種策略在記憶體方面是最優的。但是,請注意,節點計算的數量進行了n²次縮放,而先前的縮放係數為n:每個n個節點都按n次順序重新計算。由於計算速度較慢,這種方法並不適用於深度學習。

為了在記憶體和運算之間取得平衡,需要提出一個策略,允許重新計算節點,但次數不要太頻繁。在這裡使用這樣一種策略:將神經網路激活的子集標記為檢查點節點。 

在本範例中,選取將第sqrt(n)個節點標記為檢查點。這樣,檢查點節點的數量和檢查點之間的節點數量都在sqrt(n)之間,這意味著:所需的記憶體量也按n的順序進行了縮放。此策略所需的額外計算量相當於網路單次前向傳遞所需的計算量。

如何在GPU資源受限情況下微調超大模型程式:

################在學習了梯度檢查點的細節之後,來看看如何在PyTorch中應用這個概念,看起來並不太難: ##################

梯度累積/微批次

如何在GPU資源受限情況下微調超大模型

#概述

##深度學習模型正在越變越大,很難在GPU記憶體中安裝這麼大型的神經網路。因此,被迫在訓練時選用較小的批次大小,它可能導致較慢的收斂和較低的準確性。

什麼是梯度累積?

在訓練神經網路時,通常會將資料分批量處理,神經網路預測批次標籤,用於計算相對於實際目標的損失。接下來,執行反向傳遞計算出梯度,並更新模型權值。梯度累積對訓練過程的最後一步進行了修正:在繼續下一個小批之前,保存梯度值,並將新的梯度添加到先前保存的梯度中,用這種方法取代更新每個小批的網路權重。只有在模型處理了幾個小批次後,才會更新權重。梯度累積模擬了一個更大的批次大小,如果想在一個小批次中使用64張影像,如果批次大小超過了8,則會報「CUDA記憶體出錯…」。在這種情況下,可以使用8批影像,並在模型處理64/8=8批後更新一次權重。如果你從這8個批次中累積每一個梯度,結果將是(幾乎)相同的,這樣便能夠執行訓練啦!

如何在GPU資源受限情況下微調超大模型

程式:

如何在GPU資源受限情況下微調超大模型#沒有梯度累積的標準訓練環通常為: 

在PyTorch中,梯度累積可以很容易完成。模型利用accumulation_steps處理完成小批之後,便可以執行最佳化。也可以利用accumulation_steps根據損失函數的性質來分割運行損失: 

#真漂亮,對嗎?當呼叫loss.backward() 時計算梯度,並由PyTorch累積,直到呼叫optimizer.zero_grad()時停止。

重點#某些網路體系結構使用專用的批次運算,如BatchNorm,當使用相同的批次大小時,結果可能會略有不同。

混合精確度訓練

#概述

  • ##混合精準度訓練是指將部分或全部FP32參數轉換為更小的格式,如FP16、TF16(浮點張量)或BF16(浮點位元組)。

主要優點

#混合精準度訓練的主要優點是:

減少記憶體使用;如何在GPU資源受限情況下微調超大模型

#######效能提速(更高的算術強度或更小的通訊佔用);##################使用專用硬體進行更快地計算。 #####################目前只對第一個優勢感興趣-減少記憶體的使用量,來看看如何使用PyTorch模型實現它。 ##################程式:################### ########### #

結果,在完成.half()操作之後,模型變小了2倍。將模型轉換為不同的格式(即BF16,TF16)後的縮放損失,將在後續的文章中討論。有些操作在FP16中是無法完成的,例如Softmax。 PyTorch可利用torch.autocast 來處理這些特殊情況。

8位元優化器

#增加模型尺寸是獲得更佳效能的有效方法。然而,訓練大模型時需要儲存模型、梯度和最佳化器的狀態(例如,Adam的指數平滑和及先前梯度的平方和),所有這些都儲存在數量有限的可用記憶體之中。

將32位元優化器降到8位元優化器,將數值的範圍從2³²減少到僅2⁸=256,將對優化器預留的記憶體數量產生巨大的影響。

研究人員提出了一個新的8位元Adam優化器,論文作者在文中這麼說: 「它將32位元的表現維持到部分原始內存中」。

8位元優化器有三個組成部分:(1)區塊級量化,隔離異常值,將誤差均勻分配給每一個位元;(2 )動態量化,高精度地量化小值和大值;(3)穩定的嵌入層,以提高詞嵌入優化模型的穩定性。

有了這些元件,可直接使用8位元狀態執行最佳化。將8位優化器狀態量化為32位,執行更新,然後再將狀態量化為8位元進行儲存。在暫存器中逐元素進行8位到32位的轉換,無需慢速複製到GPU記憶體或額外的臨時記憶體中執行量化和去量化。對GPU來說,這意味著8位元優化器要快於常規的32位元優化器。

來看看使用8位Adam之後,鼓舞人心的結果:

如何在GPU資源受限情況下微調超大模型

可以看出,使用量化的Adam可以節省大約8.5 GB的GPU內存,看起來相當棒!

了解它的可用性之後,再來看看如何用python實作它。

由Facebook提供的Bitsandbytes 套件是一個圍繞CUDA自訂函數的輕量級包裝器,封裝了8位元優化器和量化函數,利用它可以實現8位Adam的使用。

程式:

如何在GPU資源受限情況下微調超大模型

如上所述,量化優化器的使用非常簡單,結果也不錯。

綜合上述全部方法,對GPU上的GPT-2-XL進行微調。

最後,在掌握了上述方法之後,利用這些方法來解決實際問題,對擁有15億個參數的GPT-2-XL模型進行微調。顯然,無法將它載入到12 GB記憶體的NVIDIA GeForce RTX 3060 GPU之上。列出可以使用的全部方法:

  • 梯度檢查點;
  • #混合精度訓練(我設了一個技巧:使用相同模型的兩個樣本。首先,用.half將它加載到GPU上,將其命名為gpu_model;其次,在CPU上,將其命名為cpu_model。評估好GPU模型之後,將gpu_model的梯度加載到cpu_model中,運行optimizer. step(),將更新後的參數載入到gpu_model上);
  • 使用batch_size=64,minibatch_size=4的梯度累積,需要透過accumulation_steps來縮放損失;
  • 8位元Adam優化器。

把上述方法全部利用起來,看看程式碼: 

如何在GPU資源受限情況下微調超大模型

利用上述所有方法之後,在GPU上實現了對16GB的GPT-2-XL模型微調,絕了!

結論

在本博中,給出了高效能使用記憶體的關鍵概念,它適用於多種艱鉅的任務,如上所述。將在後續的文章中討論其他概念。衷心感謝,撥冗閱讀本文!

以上是如何在GPU資源受限情況下微調超大模型的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述
本文轉載於:51CTO.COM。如有侵權,請聯絡admin@php.cn刪除
外推指南外推指南Apr 15, 2025 am 11:38 AM

介紹 假設有一個農民每天在幾週內觀察農作物的進展。他研究了增長率,並開始思考他的植物在幾週內可以生長的高度。從Th

軟AI的興起及其對當今企業的意義軟AI的興起及其對當今企業的意義Apr 15, 2025 am 11:36 AM

軟AI(被定義為AI系統,旨在使用近似推理,模式識別和靈活的決策執行特定的狹窄任務 - 試圖通過擁抱歧義來模仿類似人類的思維。 但是這對業務意味著什麼

為AI前沿的不斷發展的安全框架為AI前沿的不斷發展的安全框架Apr 15, 2025 am 11:34 AM

答案很明確 - 只是雲計算需要向雲本地安全工具轉變,AI需要專門為AI獨特需求而設計的新型安全解決方案。 雲計算和安全課程的興起 在

生成AI的3種方法放大了企業家:當心平均值!生成AI的3種方法放大了企業家:當心平均值!Apr 15, 2025 am 11:33 AM

企業家,並使用AI和Generative AI來改善其業務。同時,重要的是要記住生成的AI,就像所有技術一樣,都是一個放大器 - 使得偉大和平庸,更糟。嚴格的2024研究O

Andrew Ng的新簡短課程Andrew Ng的新簡短課程Apr 15, 2025 am 11:32 AM

解鎖嵌入模型的力量:深入研究安德魯·NG的新課程 想像一個未來,機器可以完全準確地理解和回答您的問題。 這不是科幻小說;多虧了AI的進步,它已成為R

大語言模型(LLM)中的幻覺是不可避免的嗎?大語言模型(LLM)中的幻覺是不可避免的嗎?Apr 15, 2025 am 11:31 AM

大型語言模型(LLM)和不可避免的幻覺問題 您可能使用了諸如Chatgpt,Claude和Gemini之類的AI模型。 這些都是大型語言模型(LLM)的示例,在大規模文本數據集上訓練的功能強大的AI系統

60%的問題 -  AI搜索如何消耗您的流量60%的問題 - AI搜索如何消耗您的流量Apr 15, 2025 am 11:28 AM

最近的研究表明,根據行業和搜索類型,AI概述可能導致有機交通下降15-64%。這種根本性的變化導致營銷人員重新考慮其在數字可見性方面的整個策略。 新的

麻省理工學院媒體實驗室將人類蓬勃發展成為AI R&D的核心麻省理工學院媒體實驗室將人類蓬勃發展成為AI R&D的核心Apr 15, 2025 am 11:26 AM

埃隆大學(Elon University)想像的數字未來中心的最新報告對近300名全球技術專家進行了調查。由此產生的報告“ 2035年成為人類”,得出的結論是,大多數人擔心AI系統加深的採用

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.聊天命令以及如何使用它們
4 週前By尊渡假赌尊渡假赌尊渡假赌

熱工具

DVWA

DVWA

Damn Vulnerable Web App (DVWA) 是一個PHP/MySQL的Web應用程序,非常容易受到攻擊。它的主要目標是成為安全專業人員在合法環境中測試自己的技能和工具的輔助工具,幫助Web開發人員更好地理解保護網路應用程式的過程,並幫助教師/學生在課堂環境中教授/學習Web應用程式安全性。 DVWA的目標是透過簡單直接的介面練習一些最常見的Web漏洞,難度各不相同。請注意,該軟體中

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

MantisBT

MantisBT

Mantis是一個易於部署的基於Web的缺陷追蹤工具,用於幫助產品缺陷追蹤。它需要PHP、MySQL和一個Web伺服器。請查看我們的演示和託管服務。

SublimeText3 英文版

SublimeText3 英文版

推薦:為Win版本,支援程式碼提示!

mPDF

mPDF

mPDF是一個PHP庫,可以從UTF-8編碼的HTML產生PDF檔案。原作者Ian Back編寫mPDF以從他的網站上「即時」輸出PDF文件,並處理不同的語言。與原始腳本如HTML2FPDF相比,它的速度較慢,並且在使用Unicode字體時產生的檔案較大,但支援CSS樣式等,並進行了大量增強。支援幾乎所有語言,包括RTL(阿拉伯語和希伯來語)和CJK(中日韓)。支援嵌套的區塊級元素(如P、DIV),