首页 >后端开发 >Python教程 >如何高效查找 NumPy 数组中多个最大值的索引?

如何高效查找 NumPy 数组中多个最大值的索引?

Patricia Arquette
Patricia Arquette原创
2024-12-02 13:24:151008浏览

How Can I Efficiently Find the Indices of Multiple Maximum Values in a NumPy Array?

检索 NumPy 数组中多个最大值的索引

NumPy 提供了一个方便的 np.argmax 函数来检索 NumPy 数组中最大值的索引一个数组。但是,如果您需要找到前 N 个最大值的索引怎么办?

解决方案

最近的 NumPy 版本(1.8 及更高版本)为此引入了 argpartition 函数目的。要获取前 N 个元素的索引,请按照以下步骤操作:

import numpy as np

# Original array
a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])

# Find indices of top N elements (N = 4 in this case)
ind = np.argpartition(a, -4)[-4:]

# Extract top N elements
top4 = a[ind]

# Print indices and top N elements
print("Indices:", ind)
print("Top 4 elements:", top4)

说明

np.argpartition 对数组进行部分排序,将其分为两个子数组:第一个子数组包含前 N 个元素(在本例中为最大的 4 个元素),第二个子数组包含其余元素。返回的数组 ind 包含第一个子数组中元素的索引。

此示例中的输出将是:

Indices: [1 5 8 0]
Top 4 elements: [4 9 6 9]

优化

如果还需要排序索引,可以单独排序:

sorted_ind = ind[np.argsort(a[ind])]

这一步需要O(k log k) 时间,其中 k 是要检索的顶部元素的数量。总的来说,这种方法的时间复杂度为 O(n k log k),对于大型数组和中等 k 值非常有效。

以上是如何高效查找 NumPy 数组中多个最大值的索引?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn