Maison >développement back-end >Tutoriel Python >Comportement inattendu de JAX `vmap` avec plusieurs paramètres

Comportement inattendu de JAX `vmap` avec plusieurs paramètres

王林
王林avant
2024-02-09 09:21:071067parcourir

JAX `vmap` 对于多个参数的意外行为

Contenu de la question

J'ai découvert que vmap dans jax ne se comporte pas comme prévu lorsqu'il est appliqué à plusieurs paramètres. Par exemple, considérons la fonction suivante :

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f

pour x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3). Cependant, pour les versions de vmap suivantes :

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

Il affiche cette erreur :

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

Quelqu'un peut-il expliquer la raison de cette erreur ? La sémantique de


Correct Answer


vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy est qu'elle effectue une opération par lots unique sur un ou plusieurs tableaux. Lorsque vous spécifiez in_axes=(none, 0, 0), la signification est « mapper le long des dimensions principales de y et z » : L'erreur que vous voyez vous indique que les dimensions principales de y et y ont des tailles différentes, elles ne sont donc pas compatibles par lots.

Votre fonction f1 utilise essentiellement la diffusion pour coder trois opérations par lots, donc pour reproduire cette logique en utilisant f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap vous auriez besoin de trois applications de

. Vous pouvez l'exprimer ainsi : 🎜
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Cet article est reproduit dans:. en cas de violation, veuillez contacter admin@php.cn Supprimer