首頁 >後端開發 >Python教學 >JAX `vmap` 對於多個參數的意外行為

JAX `vmap` 對於多個參數的意外行為

王林
王林轉載
2024-02-09 09:21:071132瀏覽

JAX `vmap` 对于多个参数的意外行为

問題內容

我發現 jax 中的 vmap 在應用於多個參數時不會如預期執行。例如,考慮下面的函數:

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f

對於x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),此函數的輸出形狀為(7, 5 , 3)。但是,對於以下 vmap 版本:

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

它輸出此錯誤:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

有人可以解釋一下這個錯誤背後的原因嗎?


正確答案


vmap 的語意是它對一個或多個陣列執行單一批次運算。當您指定in_axes=(none, 0, 0) 時,含義是「同時沿著yz 的前導維度映射」:您看到的錯誤告訴您yy 的前導維度具有不同的大小,因此它們不相容於批次。

您的函數f1 本質上使用廣播來編碼三個批次操作,因此要使用vmap 複製該邏輯,您將需要vmap 的三個應用程式。您可以這樣表達:

@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f

以上是JAX `vmap` 對於多個參數的意外行為的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:stackoverflow.com。如有侵權,請聯絡admin@php.cn刪除