Rumah >pembangunan bahagian belakang >Tutorial Python >JAX `vmap` tingkah laku tidak dijangka dengan berbilang parameter

JAX `vmap` tingkah laku tidak dijangka dengan berbilang parameter

王林
王林ke hadapan
2024-02-09 09:21:071107semak imbas

JAX `vmap` 对于多个参数的意外行为

Kandungan soalan

Saya mendapati bahawa vmap dalam jax tidak berkelakuan seperti yang diharapkan apabila digunakan pada berbilang parameter. Sebagai contoh, pertimbangkan fungsi berikut:

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f

untuk x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3). Walau bagaimanapun, untuk versi vmap berikut:

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

Ia mengeluarkan ralat ini:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

Bolehkah seseorang menerangkan sebab di sebalik kesilapan ini? Semantik


Jawapan Betul


vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy ialah ia melakukan operasi kumpulan tunggal pada satu atau lebih tatasusunan. Apabila anda menentukan in_axes=(none, 0, 0), maksudnya ialah "peta di sepanjang dimensi utama kedua-dua y dan z": anda Ralat yang anda lihat memberitahu anda bahawa dimensi utama y dan y mempunyai saiz yang berbeza, jadi ia tidak serasi kelompok.

Fungsi f1 anda pada asasnya menggunakan penyiaran untuk mengekod tiga operasi kelompok, jadi untuk meniru logik itu menggunakan f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap anda memerlukan tiga aplikasi

. Anda boleh menyatakannya seperti ini: 🎜
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f

Atas ialah kandungan terperinci JAX `vmap` tingkah laku tidak dijangka dengan berbilang parameter. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:stackoverflow.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam