我發現 jax 中的 vmap
在應用於多個參數時不會如預期執行。例如,考慮下面的函數:
def f1(x, y, z): f = x[:, none, none] * z[none, none, :] + y[none, :, none] return f
對於x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3)
,此函數的輸出形狀為(7, 5 , 3)
。但是,對於以下 vmap 版本:
@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2)) def f2(x, y, z): f = x*z + y return f
它輸出此錯誤:
ValueError: vmap got inconsistent sizes for array axes to be mapped: * one axis had size 5: axis 0 of argument y of type int32[5]; * one axis had size 3: axis 0 of argument z of type int32[3]
有人可以解釋一下這個錯誤背後的原因嗎?
vmap
的語意是它對一個或多個陣列執行單一批次運算。當您指定in_axes=(none, 0, 0)
時,含義是「同時沿著y
和z
的前導維度映射」:您看到的錯誤告訴您y
和y
的前導維度具有不同的大小,因此它們不相容於批次。
您的函數f1
本質上使用廣播來編碼三個批次操作,因此要使用vmap
複製該邏輯,您將需要vmap
的三個應用程式。您可以這樣表達:
@partial(vmap, in_axes=(0, None, None)) @partial(vmap, in_axes=(None, 0, None)) @partial(vmap, in_axes=(None, None, 0)) def f2(x, y, z): f = x*z + y return f
以上是JAX `vmap` 對於多個參數的意外行為的詳細內容。更多資訊請關注PHP中文網其他相關文章!