首页  >  文章  >  后端开发  >  如何在NumPy的where函数中有效组合多个条件?

如何在NumPy的where函数中有效组合多个条件?

Mary-Kate Olsen
Mary-Kate Olsen原创
2024-10-27 00:59:30262浏览

How to Effectively Combine Multiple Conditions in NumPy's where Function?

NumPy 的 where 函数中的多个条件

NumPy 中,where() 函数常用于条件选择。在处理多个条件时,了解如何有效地组合它们以获得所需的结果非常重要。

考虑一个示例,我们想要选择指定范围内的距离。以下代码尝试执行此操作:

<code class="python">dists[(np.where(dists >= r)) and (np.where(dists <= r + dr))]

但是,这会产生意外结果,仅选择第二个条件 (np.where(dists <= r dr)) 内的距离。

修复代码

要解决这个问题,我们需要了解 np.where() 返回满足条件的元素的索引,而不是布尔数组。因此,组合多个 np.where() 调用的结果不会产生布尔数组。

我们可以使用元素布尔运算符来执行所需的条件选择。以下是两种正确的实现方法:

选项 1:组合条件

<code class="python">dists[(dists >= r) & (dists <= r + dr)]

& 运算符执行元素与,生成布尔数组。然后我们可以用它来索引原始数组 dists。

选项 2:使用中间变量

<code class="python">mask1 = dists >= r
mask2 = dists <= r + dr
dists[(mask1) & (mask2)]

通过为每个条件创建临时变量,我们可以检查两个条件并使用 & 运算符组合它们以创建布尔数组。

为什么原始代码不起作用

原始代码不起作用,因为 np .where() 返回索引列表,而不是布尔数组。组合两个索引列表不会给出所需的结果。

例如:

<code class="python">dists = np.arange(0, 10, 0.5)
r = 5
dr = 1

mask1 = np.where(dists >= r)
mask2 = np.where(dists <= r + dr)

print(mask1 and mask2)
# Outputs: (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]),)</code>

如您所见,结果数组不是指示哪些元素满足这两个条件的布尔数组。

以上是如何在NumPy的where函数中有效组合多个条件?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn