ホームページ >バックエンド開発 >Python チュートリアル >JAX `vmap` が複数のパラメーターを使用した場合の予期しない動作

JAX `vmap` が複数のパラメーターを使用した場合の予期しない動作

王林
王林転載
2024-02-09 09:21:071066ブラウズ

JAX `vmap` 对于多个参数的意外行为

質問内容

jax の vmap が複数のパラメータに適用されると期待どおりに動作しないことがわかりました。たとえば、次の関数について考えてみましょう:

リーリー

x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3) の場合、この関数の出力形状は (7, 5、3)。ただし、次の vmap バージョンの場合:

リーリー

次のエラーが出力されます:

リーリー

誰かがこのエラーの背後にある理由を説明できますか?


正解


vmapのセマンティクスは、1つ以上の配列に対して単一のバッチ操作を実行することです。 in_axes=(none, 0, 0) を指定すると、「yz の両方の先頭の次元に沿ってマップ」という意味になります。このエラーは、yy の先頭のディメンションのサイズが異なるため、バッチ互換性がないことを示しています。

関数 f1 は基本的にブロードキャストを使用して 3 つのバッチ操作をエンコードしているため、vmap を使用してそのロジックを複製するには、vmap の 3 つのアプリケーションが必要になります。次のように表現できます:

リーリー

以上がJAX `vmap` が複数のパラメーターを使用した場合の予期しない動作の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事はstackoverflow.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。