ホームページ > 記事 > テクノロジー周辺機器 > PyTorch チームは、元の実装より 8 倍の速さで「すべてを分割」モデルを再実装しました。
今年の初めから現在まで、生成 AI は急速に発展しました。しかし、多くの場合、特に PyTorch を使用する場合、生成 AI のトレーニングや推論などをどのように高速化するかという難しい問題に直面しなければなりません。
この記事では、PyTorch チームの研究者がソリューションを提供します。この記事では、純粋なネイティブ PyTorch を使用して生成 AI モデルを高速化する方法に焦点を当てており、PyTorch の新しい機能とそれらを組み合わせる方法の実践例も紹介しています。
結果はどうなりましたか? PyTorch チームは、Meta の「Split Everything」(SAM) モデルを書き直した結果、精度を損なうことなく元の実装よりも 8 倍高速なコードが得られ、すべてネイティブ PyTorch を使用して最適化されたと述べました。
# ブログ アドレス: https://pytorch.org/blog/accelerated-generative-ai/
この記事を読むと、次の理解が得られます:
PyTorch のネイティブ機能によってもたらされるスループットの向上とメモリ オーバーヘッドの削減。
この研究の詳細については、Meta が提案した SAM を参照してください。詳細な記事は「CV はもう存在しない? Meta releases "Split Everything" AI model, CV might usher in GPT-3 moment」
# でご覧いただけます。 #次に、パフォーマンス分析、ボトルネックの特定、およびこれらの新機能を PyTorch に統合して SAM が直面する問題を解決する方法など、SAM の最適化プロセスを紹介します。さらに、torch.compile、SDPA、Triton カーネル、Nested Tensor、半構造化スパース性 (半構造化スパース性) など、PyTorch のいくつかの新機能も紹介します。
コンテンツ ステップステップごとに詳しく見ていき、この記事では最後に高速バージョンの SAM を紹介します。興味のある読者は、GitHub からダウンロードできます。さらに、これらのデータは Perfetto UI を使用して視覚化され、PyTorch のさまざまな機能のアプリケーション価値を実証しました。
GitHub アドレス: https://github.com/pytorch-labs/segmentこのプロジェクトのコードは、-anything-fast にあります。
研究では、次のことが指摘されています。 SAM ベースライン データ型は float32 dtype、バッチ サイズは 1、PyTorch Profiler を使用してコア トレースを表示した結果は次のとおりです。
この記事では、SAM には最適化できる箇所が 2 つあることがわかりました。
1 つ目は、aten::index への長い呼び出しです。これは、テンソル インデックス操作 ([] など) 生成された基礎的な呼び出しによって引き起こされます。ただし、GPU が aten::index に費やす実際の時間は比較的短く、その理由は、2 つのコアを起動するプロセス中に、aten::index が 2 つのコア間の cudaStreamSynchronize をブロックするためです。これは、2 番目のコアが起動されるまで、CPU が GPU の処理が完了するのを待つことを意味します。したがって、SAM を最適化するには、アイドル時間を引き起こす GPU 同期のブロックを排除するよう努めるべきであるとこの文書では考えています。
2 番目の問題は、SAM が行列の乗算 (図に示す濃い緑色の部分) に多くの GPU 時間を費やしていることです。これは、Transformers モデルでは非常に一般的です。行列乗算における SAM モデルの GPU 時間を削減できれば、SAM
の速度を大幅に向上させることができます。 次に、SAM のスループット (img/s ) を取得します。ベースラインを確立するためのメモリ オーバーヘッド (GiB)。次に、最適化プロセスがあります。
#書き直す必要がある文は次のとおりです。 Bfloat16 半精度 (プラス) GPU 同期とバッチ処理)
#上記の問題を解決するために、つまり行列の乗算に必要な時間を短縮するために、この記事では bfloat16 を取り上げます。 bfloat16 は一般的に使用される半精度型であり、各パラメータとアクティベーションの精度を下げることで、計算時間とメモリを大幅に節約できます
##fill type を bfloat16 に置き換えます
さらに、この記事では、GPU 同期を削除するために最適化できる場所が 2 つあることがわかりました
具体的には、上の図、研究結果 SAM 画像エンコーダには、座標スケーラーとして機能する 2 つの変数 q_coords と k_coords があり、これらの変数は CPU 上で割り当てられて処理されます。ただし、これらの変数が rel_pos_resize でのインデックス付けに使用されると、インデックス付け操作によってこれらの変数が GPU に自動的に移動され、GPU 同期の問題が発生します。この問題を解決するために、調査では、上記のように、torch.where 関数を使用してこの部分を書き換えることができることが指摘されました。
Core Tracking
これらの変更を適用した後、特に小さなバッチ (ここでは 1) の場合に、個々のカーネル呼び出し間に顕著な時間のギャップがあることに気付きました。この現象をより深く理解するために、バッチ サイズ 8
で SAM 推論のパフォーマンス分析を開始します。カーネルごとに費やされる時間を分析すると、SAM の GPU 時間のほとんどが要素ごとのカーネルとソフトマックス操作に費やされていることがわかります。
#行列乗算のオーバーヘッドが比較的小さいことがわかります。たくさん。
GPU 同期と bfloat16 最適化を組み合わせると、SAM パフォーマンスが 3 倍向上します。
SAMの研究中に発見された多くの小さな操作が実行されました。研究者は、コンパイラを使用してこれらの操作を統合することが非常に有益であると考えているため、PyTorch は torch.compile に対して次の最適化を行いました。
これらの最適化により、この研究では GPU グローバル メモリの往復回数が削減され、推論が高速化されました。 SAM の画像エンコーダで torch.compile を試すことができるようになりました。パフォーマンスを最大化するために、この記事ではいくつかの高度なコンパイル手法を使用します。
#コアトラッキング
## によると結果を見ると、torch.compile のパフォーマンスは非常に良好です。softmax が時間の大部分を占め、その後にさまざまな GEMM が続いていることがわかります。亜種。次の測定値は、バッチ サイズ 8 以上のものです。
SDPA:scaled_dot_product_attention
コア トラッキング
お客様が利用できるようになりました。メモリ効率の高いアテンション カーネルは、GPU で多くの計算時間を消費します:
PyTorch のネイティブのscaled_dot_product_attentionを使用すると、バッチ サイズを大幅に増やすことができます。以下のグラフは、バッチ サイズ 32 以上の変化を示しています。
次に、Triton、NestedTensor、バッチ Predict_torch、int8 量子化、半構造化 (2:4) スパース性に関する研究が行われました。操作
たとえば、この記事ではカスタム位置 Triton カーネルを使用し、バッチ サイズ 32 での測定結果を観察します。
#Nested Tensor テクノロジを採用し、バッチ サイズを 32 以上に調整します
##量子化を追加すると、バッチ サイズが 32 以上になると測定結果が異なります。
################記事の最後は、半構造化されたスパーシティです。この研究は、行列の乗算が依然として直面する必要のあるボトルネックであることを示しています。解決策は、スパース化を使用して行列乗算を近似することです。スパース行列 (つまり、値をゼロにする) により、重みとアクティベーション テンソルを格納するために使用できるビットが少なくなります。テンソル内のどの重みがゼロに設定されるかを設定するプロセスは、枝刈りと呼ばれます。より小さい重みを取り除くと、精度を大幅に損なうことなくモデルのサイズを削減できる可能性があります。 ######プルーニングには、完全に構造化されていないものから高度に構造化されたものまで、さまざまな方法があります。非構造化枝刈りは理論的には精度に最小限の影響を与えますが、スパースの場合、GPU は大規模な密行列の乗算を実行するときに非常に効率的であるにもかかわらず、大幅なパフォーマンスの低下を経験する可能性があります。 PyTorch で最近サポートされたプルーニング手法の 1 つは、バランスを見つけることを目的とした半構造化 (または 2:4) スパース性です。この疎な格納方法は、元のテンソルを 50% 削減しながら、高密度のテンソル出力を生成します。説明については、以下の図を参照してください。
このスパース ストレージ形式と関連する高速カーネルを使用するには、次に行うことは次のとおりです。重みを切り詰めます。この記事では、スパース度 2:4 でプルーニングするための最小の 2 つの重みを選択します。重みをデフォルトの PyTorch (「ストライド」) レイアウトからこの新しい半構造化スパース レイアウトに変更するのは簡単です。 apply_sparse (モデル) を実装するには、32 行の Python コードのみが必要です。
スパース性が 2:4 の場合、vit_b とバッチ サイズ 32 の SAM ピーク パフォーマンス
最後に、この記事の概要は次のとおりです。 この記事では、これまでの PyTorch での実装について説明します。正式にリリースされた一連の新機能を利用して、何でもセグメント化する方法。この記事では、正確さを失うことなく、純粋な PyTorch で元の SAM を書き直しています。
興味のある読者は、元の SAM を確認してください。詳細についてはブログをご覧ください
以上がPyTorch チームは、元の実装より 8 倍の速さで「すべてを分割」モデルを再実装しました。の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。