JAX `vmap` unexpected behavior with multiple parameters

王林
2024-02-09 09:21:07

Question content

I found that vmap in JAX does not behave as I expected when applied to a function with multiple parameters. For example, consider the following function:

def f1(x, y, z):
    f = x[:, None, None] * z[None, None, :] + y[None, :, None]
    return f

For x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3), the output shape of this function is (7, 5, 3). However, consider the following vmap version:

@partial(vmap, in_axes=(None, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

It raises this error:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

Can someone explain the reason behind this error?


Correct answer


The semantics of vmap are that it performs a single batched operation over one or more arrays. When you specify in_axes=(None, 0, 0), the meaning is "map along the leading dimensions of both y and z". The error you are seeing tells you that the leading dimensions of y and z have different sizes, so they are not batch-compatible.
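
To make the single-batch semantics concrete, here is a minimal sketch. The function name g and the array z_same are illustrative assumptions, not part of the question. Because y and z_same share a leading dimension of size 5, in_axes=(None, 0, 0) is valid and the result has exactly one batch axis:

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap


# Hypothetical helper for illustration: same body as f2, one mapped axis.
@partial(vmap, in_axes=(None, 0, 0))
def g(x, y, z):
    # Inside vmap, x is the full array; y and z are scalars taken from the
    # same batch index.
    return x * z + y


x = jnp.arange(7)
y = jnp.arange(5)
z_same = jnp.arange(5)  # same leading size as y, so the batch is consistent

out = g(x, y, z_same)
print(out.shape)  # (5, 7): one batch axis of size 5, followed by x's axis

# Mismatched leading sizes (5 vs. 3) reproduce the error from the question.
try:
    g(x, y, jnp.arange(3))
    mismatch_error = None
except ValueError as err:
    mismatch_error = err
    print("ValueError:", str(err).splitlines()[0])
```

Note that y and z are paired index by index here, which is a different computation from the three independent axes that f1 produces.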

Your function f1 essentially uses broadcasting to encode three batched operations, so to replicate that logic using vmap you will need three applications of vmap. You can express it like this:

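from functools import partial
from jax import vmap

# Each vmap maps exactly one argument; the outermost one maps x (output axis 0).
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f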
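
As a quick check, the sketch below compares the broadcasting version from the question with the triple-vmap version on the same inputs. It is a self-contained illustration under the question's assumptions (x, y, z built with jnp.arange), not a benchmark:

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap


def f1(x, y, z):
    # Broadcasting version from the question.
    return x[:, None, None] * z[None, None, :] + y[None, :, None]


# The outermost vmap maps x (output axis 0), the middle one maps y (axis 1),
# and the innermost maps z (axis 2).
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    # Here x, y, and z are all scalars.
    return x * z + y


x, y, z = jnp.arange(7), jnp.arange(5), jnp.arange(3)

print(f1(x, y, z).shape)  # (7, 5, 3)
print(f2(x, y, z).shape)  # (7, 5, 3)
print(bool(jnp.array_equal(f1(x, y, z), f2(x, y, z))))  # True
```

The order of the decorators determines the order of the output axes: reordering them would permute the axes of the result.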

Statement:
This article is reproduced from stackoverflow.com. If there is any infringement, please contact admin@php.cn to request deletion.