首頁 >科技週邊 >人工智慧 >Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

王林
王林轉載
2024-04-01 19:46:111313瀏覽

Google力推的JAX在最近的基準測試中效能已經超過Pytorch和TensorFlow,7項指標排名第一。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

而且測試並不是JAX效能表現最好的TPU上完成的。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

雖然現在在開發者中,Pytorch依然比Tensorflow更受歡迎。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

但未來,也許有更多的大模型會基於JAX平台進行訓練和運行。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

模型

最近,Keras團隊為三個後端(TensorFlow、JAX、PyTorch)與原生PyTorch實作以及搭配TensorFlow的Keras 2進行了基準測試。

首先,他們為生成式和非生成式人工智慧任務選擇了一組主流的電腦視覺和自然語言處理模型:

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

對於模型的Keras版本,其採用了KerasCV和KerasNLP中已有的實作進行建構。而對於原生的PyTorch版本,則選擇了網路上最受歡迎的幾個選項:

- 來自HuggingFace Transformers的BERT、Gemma、Mistral

#- 來自HuggingFace Diffusers的StableDiffusion

- 來自Meta的SegmentAnything

#他們將這組模型稱為「Native PyTorch」,以便與使用PyTorch後端的Keras 3版本進行區分。

他們對所有基準測試都使用了合成數據,並在所有LLM訓練和推理中使用了bfloat16精度,同時在所有LLM訓練中使用了LoRA(微調)。

根據PyTorch團隊的建議,他們在原生PyTorch實作中使用了torch.compile(model, mode="reduce-overhead")(由於不相容,Gemma和Mistral訓練除外)。

為了衡量開箱即用的效能,他們使用高階API(例如HuggingFace的Trainer()、標準PyTorch訓練循環和Keras model.fit()),並盡可能減少配置。

硬體配置

所有基準測試都使用Google Cloud Compute Engine進行,配置為:一塊擁有40GB記憶體的NVIDIA A100 GPU、12個虛擬CPU和85GB的主機記憶體。

基準測試結果

表2顯示了基準測試結果(以步/毫秒為單位)。每個步驟都涉及對單一資料批次進行訓練或預測。

結果是100步的平均值,但排除了第一步,因為第一步包括了模型創建和編譯,這會額外花費時間。

為了確保比較的公平性,對於相同的模型和任務(不論是訓練還是推理)都使用相同的批次大小。

然而,對於不同的模型和任務,由於它們的規模和架構有所不同,可根據需要調整資料批大小,從而避免因過大而導致記憶體溢出,或是批過小而導致GPU使用不足。

過小的批次大小也會使PyTorch看起來較慢,因為會增加Python的開銷。

對於大型語言模型(Gemma和Mistral),測試時也使用了相同的批次大小,因為它們是相同類型的模型,具有類似數量的參數(7B)。

考慮到使用者對單批文字產生的需求,也對批次大小為1的文字產生情況進行了基準測試。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

關鍵發現

發現1

#不存在「最優」後端。

Keras的三種後端各展所長,重要的是,就效能而言,並沒有哪一個後端能夠始終勝出。

選擇哪個後端最快,往往取決於模型的架構。

這一點突顯了選擇不同框架以追求最佳效能的重要性。 Keras 3可以協助輕鬆切換後端,以便為模型找到最合適的選擇。

發現2

#Keras 3的效能普遍超過PyTorch的標準實作。

相對於原生PyTorch,Keras 3在吞吐量(步/毫秒)上有明顯的提升。

特別是,在10個測試任務中,有5個的速度提升超過了50%。其中,最高更是達到了290%。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

如果是100%,表示Keras 3的速度是PyTorch的2倍;如果是0%,則表示兩者效能相當

發現3

Keras 3提供一流的「開箱即用」效能。

也就是,所有參與測試的Keras模型都未進行過任何最佳化。相較之下,使用原生PyTorch實作時,通常需要使用者自行進行更多效能最佳化。

除了上面分享的數據,測試中還注意到在HuggingFace Diffusers的StableDiffusion推理功能上,從版本0.25.0升級到0.3.0時,性能提升超過了100% 。

同樣,在HuggingFace Transformers中,Gemma從4.38.1版本升級至4.38.2版本也顯著提高了效能。

這些效能的提升凸顯了HuggingFace在效能優化上的專注與努力。

對於一些手動最佳化較少的模型,如SegmentAnything,則使用了研究作者提供的實作。在這種情況下,與Keras相比,效能差距比大多數其他型號更大。

這表明,Keras能夠提供卓越的開箱即用效能,使用者無需深入了解所有最佳化技巧即可享受到快速的模型運行速度。

發現4

#Keras 3的表現總是優於Keras 2。

例如,SegmentAnything的推理速度提升了驚人的380%,StableDiffusion的訓練處理速度提升了150%以上,BERT的訓練處理速度也提升了100%以上。

這主要是因為Keras 2在某些情況下直接使用了更多的TensorFlow融合操作,而這可能對於XLA的編譯並不是最佳選擇。

值得注意的是,即使只升級到Keras 3並繼續使用TensorFlow後端,也能顯著提升效能。

Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇

結論

框架的表現在很大程度上取決於具體使用的模型。

Keras 3能夠幫助為任務選擇最快的框架,這種選擇幾乎總是超越Keras 2和PyTorch實作。

更為重要的是,Keras 3模型無需進行複雜的底層最佳化,即可提供卓越的開箱即用效能。

以上是Google狂喜:JAX性能超越Pytorch、TensorFlow!或成GPU推理訓練最快選擇的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:51cto.com。如有侵權,請聯絡admin@php.cn刪除