Heim >Backend-Entwicklung >Python-Tutorial >JAX „vmap' unerwartetes Verhalten mit mehreren Parametern

JAX „vmap' unerwartetes Verhalten mit mehreren Parametern

王林
王林nach vorne
2024-02-09 09:21:071073Durchsuche

JAX `vmap` 对于多个参数的意外行为

Frageninhalt

Ich habe festgestellt, dass sich vmap in Jax nicht wie erwartet verhält, wenn es auf mehrere Parameter angewendet wird. Betrachten Sie beispielsweise die folgende Funktion:

def f1(x, y, z):
    f = x[:, none, none] * z[none, none, :] + y[none, :, none]
    return f

für x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3). Allerdings für die folgenden vmap-Versionen:

@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

Es wird dieser Fehler ausgegeben:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

Kann jemand den Grund für diesen Fehler erklären? Die Semantik von


Richtige Antwort


vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy besteht darin, dass eine einzelne Stapeloperation für ein oder mehrere Arrays ausgeführt wird. Wenn Sie in_axes=(none, 0, 0) angeben, bedeutet dies „Abbildung entlang der führenden Dimensionen von y und z“: Sie Der angezeigte Fehler weist darauf hin, dass die führenden Dimensionen von y und y unterschiedliche Größen haben und daher nicht stapelkompatibel sind.

Ihre Funktion f1 verwendet im Wesentlichen Broadcasting, um drei Stapeloperationen zu kodieren. Um diese Logik mit f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap zu replizieren, benötigen Sie also drei Anwendungen von

. Du kannst es so ausdrücken: 🎜
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f

Das obige ist der detaillierte Inhalt vonJAX „vmap' unerwartetes Verhalten mit mehreren Parametern. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:stackoverflow.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen