首頁 >科技週邊 >人工智慧 >PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

WBOY
WBOY轉載
2023-11-22 15:45:37940瀏覽
我們該如何優化 Meta 的「分割一切」模型,PyTorch 團隊撰寫的這篇部落格由淺入深的幫你解答。

從年初到現在,生成式 AI 發展迅速。但很多時候,我們又必須面對一個難題:如何加速生成式 AI 的訓練、推理等,尤其是在使用 PyTorch 的情況下。

本文 PyTorch 團隊的研究者為我們提供了一個解決方案。文章重點介紹如何使用純原生 PyTorch 加速生成式 AI 模型,此外,文章還介紹了 PyTorch 新功能,以及如何組合這些功能的實際範例。

結果如何? PyTorch 團隊表示,他們重寫了 Meta 的「分割一切」 (SAM) 模型,從而使程式碼比原始實現快 8 倍,並且沒有損失準確率,所有這些都是使用原生 PyTorch 進行優化的。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

部落格網址:https://pytorch.org/blog/accelerating-generative-ai/

看完本文,你將會了解到:

  • Torch.compile:PyTorch 模型編譯器, PyTorch 2.0 加入了一個新的函數,叫做torch. compile (),能夠透過一行程式碼對現有的模型進行加速;
  • GPU 量化:透過降低運算精度來加速模型;
  • SDPA(Scaled Dot Product Attention ):記憶體高效的注意力實作方式;
  • 半結構化(2:4) 稀疏性:一種針對GPU 最佳化的稀疏記憶體格式;
  • Nested Tensor:Nested Tensor 把{tensor, mask} 打包在一起,將非均勻大小的資料批次處理到單張量中,例如不同大小的圖片;
  • #Triton 自訂操作:使用Triton Python DSL 編寫GPU 操作,並透過自訂操作符註冊輕鬆將其整合到PyTorch 的各種元件中。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

                                中之後所帶來的吞吐量上與記憶特性上的吞吐量增加的吞吐量上以及減少記憶中產生的吞吐量增加。

SAM 由Meta 提出,關於這項研究的更多內容請參考「CV 不存在了?Meta 發布「分割一切」AI 模型,CV 或迎來GPT-3 時刻」。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

接下來,文章介紹了 SAM 最佳化過程,包括效能分析、瓶頸識別,以及如何將這些新功能整合進 PyTorch 以解決 SAM 面臨的這些問題。除此之外,本文也介紹了 PyTorch 的一些新功能:torch.compile、SDPA、Triton kernels、Nested Tensor 以及 semi-structured sparsity(半結構化稀疏)。

本文內容逐層深入,文章的最後會介紹快速版SAM,有興趣的小夥伴可以去GitHub 上下載,此外,本文也透過Perfetto UI 對這些數據進行了視覺化,以此來闡述PyTorch 每項特性的應用價值。

GitHub 網址:https://github.com/pytorch-labs/segment-anything-fast

對分割一切模型SAM 的重寫

該研究表示,本文利用的SAM 基準資料類型為float32 dtype、batch 大小為1,使用PyTorch Profiler 查看核心追蹤的結果如下:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

本文發現SAM 有兩個地方可以優化:

第一個是對aten::index 的長調用,這是由張量索引操作(例如[])產生的底層呼叫所導致的。然而實際上 GPU 花在 aten::index 上的時間相對較低,原因在於 aten::index 在啟動兩個核心的過程中,兩者之間發生了阻塞 cudaStreamSynchronize。這意味著 CPU 會等待 GPU 完成處理,直到啟動第二個核心。因而為了優化 SAM,本文認為應該致力於消除導致空閒時間的阻塞 GPU 同步。

第二個是 SAM 在矩陣乘法中花費了大量的 GPU 時間(上圖中的深綠色),這在 Transformers 中很常見。如果能夠減少 SAM 模型在矩陣乘法上花費的 GPU 時間,我們就可以顯著加快 SAM 的速度。

接下來本文用 SAM 的吞吐量 (img/s) 和記憶體開銷 (GiB) 來建立基準。之後就是優化過程了。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

Bfloat16 半精確度(加上GPU 同步與批次)

為了解決上述問題,即讓矩陣乘法花費的時間更少,本文轉向bfloat16。 Bfloat16 是常用的半精度類型,透過降低每個參數和啟動的精度,能夠節省大量的計算時間和記憶體。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

                             用bfloat16 替換padding 類型

此外,為了移除GPU 同步,本文發現有兩個位置可以最佳化。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

#具體來說(參考上圖比較容易理解,出現的變數名稱都在程式碼中),研究發現在SAM的影像編碼器中,有充當座標縮放器(coordinate scalers)的變數q_coords 和k_coords,這些變數都是在CPU 上分配和處理的。然而,一旦這些變數被用來在 rel_pos_resized 中建立索引,這些索引操作就會自動的將這些變數移至 GPU 上,這種複製會導致 GPU 同步。為了解決上述問題,研究注意到可以使用 torch.where 重寫這部分內容來解決問題,如上所示。

內核追蹤

#在應用了這些變更之後,本文注意到單一內核呼叫之間有著顯著的時間間隔,尤其在小批量(這裡為1)時更為突出。為了更深入的了解這一現象,本文開始對批次大小為8 的SAM 推理進行效能分析:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

在查看每個核心所花費的時間時,本文觀察到SAM 的大部分GPU 時間都花費在逐元素核心(elementwise kernels)和softmax 操作上。

現在可以看到矩陣乘法的相對開銷小了很多。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

將 GPU 同步和 bfloat16 最佳化結合在一起,SAM 效能提高了 3 倍。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

Torch.compile( graph breaks 和CUDA graphs)

#本文發現在深入研究SAM 的過程中有很多小的操作,他們認為使用編譯器來融合操作有很大的好處,因而PyTorch 對torch.compile 做了以下最佳化:

  • ##將nn.LayerNorm 或nn.GELU 等操作序列融合成單一的GPU 核心;
  • 融合緊接在矩陣乘法核心之後的操作,以減少GPU 核心呼叫的數量。

透過這些最佳化,研究減少了 GPU 全域記憶體往返次數(roundtrips),從而加快了推理速度。我們現在可以在 SAM 的圖像編碼器上嘗試 torch.compile。為了最大限度地提高效能,本文使用了一些進階編譯技術:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

#核心追蹤

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

結果顯示,torch.compile 運作得很好。 PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

可以觀察到 softmax 佔了很大一部分時間,然後是各種 GEMM 變體。以下測量的是批次大小為 8 及以上的變化。 PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

SDPA: scaled_dot_product_attention

##接下來,本文又對SDPA(scaled_dot_product_attention)進行了實驗,研究的重點是注意力機制。一般來講,原生注意力機制在時間和記憶體上隨序列長度呈二次方擴展。 PyTorch 的 SDPA 操作基於 Flash Attention、FlashAttentionV2 和 xFormer 的記憶體高效注意力原理構建,可以顯著加快 GPU 注意力。與 torch.compile 結合,這個操作允許在 MultiheadAttention 的變體中表達和融合一個共同的模式。經過一小部分變更後,現在模型可以使用 scaled_dot_product_attention。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

核心追蹤

#現在可以看到記憶體高效的注意力核心佔用了GPU 上大量的運算時間:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

使用PyTorch 的原生scaled_dot_product_attention,可以大幅增加批次大小。下圖為批次大小為 32 以上的變化。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

之後,研究又實驗了 Triton,NestedTensor 、批次 Predict_torch, int8 量化,半結構化 (2:4) 稀疏性等操作。

例如本文使用自訂 positional Triton 內核,觀察到批次大小為 32 的測量結果。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

使用 Nested Tensor,批次大小為 32 以上的變化。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

加入量化後,批次大小為 32 以上變化的測量結果。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

###
文章的最後是半結構稀疏性。研究表示,矩陣乘法仍然是需要面對的瓶頸。解決的辦法是使用稀疏化來近似矩陣乘法。透過稀疏矩陣(即將值歸零)可以使用更少的位元來儲存權重和激活張量。該研究將張量中哪些權重設為零的過程稱為剪枝。剪枝掉較小的權重可以潛在地減少模型大小,而不會顯著損失準確率。

剪枝的方法多種多樣,從完全非結構化到高度結構化。雖然非結構化剪枝理論上對精度的影響最小,但 GPU 在進行大型密集矩陣乘法方面儘管非常高效,然而在稀疏情況下可能還會遭受顯著的性能下降。 PyTorch 最近支持的一種剪枝方法旨在尋求平衡,稱為半結構化(或 2:4)稀疏性。這種稀疏儲存將原始張量減少了 50%,同時產生密集張量輸出。參見下圖的說明。

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

為了使用這種稀疏儲存格式和相關的快速內核,接下來要做的是剪枝權重。本文在 2:4 的稀疏度下選擇最小的兩個權重進行剪枝,將權重從預設的 PyTorch(“strided”)佈局更改為這種新的半結構化稀疏佈局很容易。要實作apply_sparse (model),只需要32 行Python 程式碼:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

在2:4 的稀疏度下,本文觀察到vit_b 和批次大小為32 時的SAM峰值性能:

PyTorch團隊重寫「分割一切」模型,比原始實現快8倍

最後,一句話總結這篇文章:本文介紹了迄今為止在PyTorch 上最快的Segment Anything 實現方式,借助官方發布的一系列新功能,本文在純PyTorch 中重寫了原始SAM,並且沒有損失準確率。

有興趣的讀者可以查看原始部落格以了解更多內容。

參考連結:https://pytorch.org/blog/accelerating-generative-ai/

以上是PyTorch團隊重寫「分割一切」模型,比原始實現快8倍的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:jiqizhixin.com。如有侵權,請聯絡admin@php.cn刪除