首頁  >  文章  >  後端開發  >  如何用(N-1)維數組高效率存取N維數組?

如何用(N-1)維數組高效率存取N維數組?

Susan Sarandon
Susan Sarandon原創
2024-10-21 11:57:03138瀏覽

How to Access N-Dimensional Array with (N-1)-Dimensional Array Efficiently?

使用(N-1) 維數組存取N 維數組

給定一個N 維數組a 和一個(N- 1)維數組idx,一個常見的任務是存取idx中索引指定的元素。這對於執行查找最大值或檢索特定值等操作非常有用。

使用高級索引的優雅解決方案

一個優雅的解決方案涉及使用NumPy 的ogrid 函數進行高級索引:

<code class="python">m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]</code>

這將建立一個索引網格,並使用它來索引a 和b,從而產生包含相應值的陣列。

函數的一般情況

對於適用於任何指定軸的更通用的解決方案,我們可以定義一個函數:

<code class="python">def argmax_to_max(arr, argmax, axis):
    new_shape = list(arr.shape)
    del new_shape[axis]

    grid = np.ogrid[tuple(map(slice, new_shape))]
    grid.insert(axis, argmax)

    return arr[tuple(grid)]</code>

此函數接受一個陣列、其沿指定軸的argmax 以及軸本身。然後它構造一個網格並使用它來提取相應的元素。

使用自訂函數簡化索引

為了進一步簡化索引過程,我們可以建立一個輔助函數來產生索引網格:

<code class="python">def all_idx(idx, axis):
    grid = np.ogrid[tuple(map(slice, idx.shape))]
    grid.insert(axis, idx)
    return tuple(grid)</code>

此函數傳回索引元組,可直接用於索引輸入數組:

<code class="python">axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]</code>

以上是如何用(N-1)維數組高效率存取N維數組?的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn