Index multiple elements in multidimensional numpy array
I want to extract elements of a multidimensional NumPy array using another array of indices, but it doesn't behave as I expected. Here is a simple example:
```python
import numpy as np

a = np.random.random((3, 3, 3))
idx = np.asarray([[0, 0, 0], [0, 1, 2]])
b = a[idx]
print(b.shape)  # expected (2,), got (2, 3, 3, 3)
```
Why is this so? How should I modify the code so that b contains only the two elements a[0, 0, 0] and a[0, 1, 2]?
You are looking for NumPy advanced indexing:
https://www.php.cn/link/2d661a763280f48803f3c9ba8ba0e00b
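To see why the original a[idx] returns shape (2, 3, 3, 3): when a single integer array is used as the index, NumPy applies it to the first axis only, so every entry of idx selects a whole 3×3 slice a[i], and the result shape is idx.shape + a.shape[1:]. A quick check, reusing the array from the question:

```python
import numpy as np

a = np.random.random((3, 3, 3))
idx = np.asarray([[0, 0, 0], [0, 1, 2]])

b = a[idx]  # a single index array: it indexes axis 0 only
# Result shape is idx.shape + a.shape[1:] == (2, 3) + (3, 3)
print(b.shape)  # (2, 3, 3, 3)

# Each entry idx[i, j] selected the full slice a[idx[i, j]]
print(np.array_equal(b[1, 2], a[2]))  # True
```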
In your case, you need to index each axis with the corresponding column of idx:

```python
a[idx[:, 0], idx[:, 1], idx[:, 2]].shape == (2,)  # True
```
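The per-axis indexing above can be written without naming each column by transposing idx (one row per axis) and unpacking it with tuple(). The sketch below assumes, as in the question, that each row of idx is the full coordinate of one element. It also shows an alternative that is not part of the original answer: np.ravel_multi_index converts the coordinates to flat C-order offsets, and np.take reads those offsets from the flattened array.

```python
import numpy as np

a = np.random.random((3, 3, 3))
# Each row of idx is the full coordinate of one element to extract
idx = np.asarray([[0, 0, 0], [0, 1, 2]])

# idx.T has one row per axis; tuple() unpacks it into
# (idx[:, 0], idx[:, 1], idx[:, 2]), so each axis gets its own index array
b = a[tuple(idx.T)]
print(b.shape)  # (2,)

# Alternative: convert the coordinates to flat (C-order) offsets,
# then take those positions from the flattened array
flat = np.ravel_multi_index(tuple(idx.T), a.shape)
c = np.take(a, flat)
print(np.array_equal(b, c))  # True
```

Using tuple(idx.T) keeps the same line working if the array has more or fewer dimensions, as long as idx has one column per axis.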