首页 >后端开发 >Python教程 >如何使用低维索引数组高效索引 N 维数组?

如何使用低维索引数组高效索引 N 维数组?

Mary-Kate Olsen
Mary-Kate Olsen原创
2024-10-21 13:09:021077浏览

How to Efficiently Index N-Dimensional Arrays with Lower-Dimensional Index Arrays?

使用 (N-1) 维数组索引 N 维数组

使用 (N) 访问 N 维数组-1) 维数组在寻找沿特定维度对齐的值时提出了挑战。使用 np.argmax 的传统方法可能不够。

高级索引方法

可以通过使用 np.ogrid 的高级索引来实现优雅的索引。对于 3D 数组 a 及其沿第一维的 argmax,idx:

import numpy as np

a = np.random.random_sample((3, 4, 4))
idx = np.argmax(a, axis=0)

m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]
a_max_values = a[idx, I, J]

此方法创建一个网格,可有效地将索引数组扩展到原始数组的完整维度。

任意维度的泛化

对于更泛化的解决方案,可以定义 argmax_to_max() 函数:

def argmax_to_max(arr, argmax, axis):
    new_shape = list(arr.shape)
    del new_shape[axis]

    grid = np.ogrid[tuple(map(slice, new_shape))]
    grid.insert(axis, argmax)

    return arr[tuple(grid)]

该函数采用原始数组及其 argmax,和所需的轴并返回相应的最大值。

通用索引的替代方法

用于使用 (N-1) 维索引任何 N 维数组数组, all_idx() 函数是一个更简化的解决方案:

def all_idx(idx, axis):
    grid = np.ogrid[tuple(map(slice, idx.shape))]
    grid.insert(axis, idx)
    return tuple(grid)

使用此函数,可以使用 idx 沿 axis 索引数组 a 来完成:

axis = 0
a_max_values = a[all_idx(idx, axis=axis)]

以上是如何使用低维索引数组高效索引 N 维数组?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn