“带有广播和布尔掩码的花式索引”如何工作?

时间:2020-01-27 14:39:14

标签: python arrays numpy masking array-broadcasting

我在Jake Vanderplas的Data Science Handbook中遇到了这段代码。对我来说,将广播与花式索引一起使用的概念尚不清楚。请解释。

In[5]: X = np.arange(12).reshape((3, 4))
 X
Out[5]: array([[ 0, 1, 2, 3],
 [ 4, 5, 6, 7],
 [ 8, 9, 10, 11]])

In[6]: row = np.array([0, 1, 2])
 col = np.array([2, 1, 3])

In[7]: X[row[:, np.newaxis], col]
Out[7]: array([[ 2, 1, 3],
               [ 6, 5, 7],
              [10, 9, 11]])

它说:“这里,每个行值都与每个列向量匹配,就像我们在广播算术运算时所看到的那样。例如:”

In[8]: row[:, np.newaxis] * col
Out[8]: array([[0, 0, 0],
               [2, 1, 3],
               [4, 2, 6]])

2 个答案:

答案 0 :(得分:0)

如果您使用整数数组来索引另一个数组 您基本上可以遍历给定的索引,并沿要索引的轴选择相应的元素(可能仍然是数组),并将它们堆叠在一起。

arr55 = np.arange(25).reshape((5, 5))
# array([[ 0,  1,  2,  3,  4],
#        [ 5,  6,  7,  8,  9],
#        [10, 11, 12, 13, 14],
#        [15, 16, 17, 18, 19],
#        [20, 21, 22, 23, 24]])

arr53 = arr55[:, [3, 3, 4]]  
# pick the elements at (arr[:, 3], arr[:, 3], arr[:, 4])
# array([[ 3,  3,  4],
#        [ 8,  8,  9],
#        [13, 13, 14],
#        [18, 18, 19],
#        [23, 23, 24]])

因此,如果您使用长度为(m, n)(或长度为k)的行(或col)索引来索引l数组,则结果形状为:

A_nm[row, :] -> A_km
A_nm[:, col] -> A_nl

但是,如果您使用两个数组rowcol来索引一个数组 您可以同时遍历两个索引,并将各个位置的元素(可能仍然是数组)堆叠在一起。 在这里rowcol的长度必须相同。

A_nm[row, col] -> A_k
array([ 3, 13, 24])

arr3 = arr55[[0, 2, 4], [3, 3, 4]]  
# pick the element at (arr[0, 3], arr[2, 3], arr[4, 4])

现在终于可以提您的问题了:在对数组建立索引时可以使用广播。有时不希望仅元素

(arr[0, 3], arr[2, 3], arr[4, 4])
选择了

,而是选择了扩展版本:

(arr[0, [3, 3, 4]], arr[2, [3, 3, 4]], arr[4, [3, 3, 4]])
# each row value is matched with each column vector

此匹配/广播与其他算术运算完全相同。 但是,此处的示例在某种意义上可能是不好的,因为所示乘法的结果对索引而言并不重要。 这里的重点是组合和最终的形状:

row * col  
# performs a element wise multiplication resulting in 3 
numbers
row[:, np.newaxis] * col 
# performs a multiplication where each row value is *matched* with each column vector

该示例希望强调rowcol的这种匹配。

我们可以看看并尝试各种可能性:

n = 3
m = 4
X = np.arange(n*m).reshape((n, m))
row = np.array([0, 1, 2])  # k = 3
col = np.array([2, 1, 3])  # l = 3

X[row, :]  # A_nm[row, :] -> A_km
# array([[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]])

X[:, col]  # A_nm[:, col] -> A_nl
# array([[ 2,  1,  3],
#        [ 6,  5,  7],
#        [10,  9, 11]])

X[row, col]  # A_nm[row, col] -> A_l == A_k
# array([ 2,  5, 11]

X[row, :][:, col]  # A_nm[row, :][:, col] -> A_km[:, col] -> A_kl 
# == X[:, col][row, :]
# == X[row[:, np.newaxis], col]  # A_nm[row[:, np.newaxis], col] -> A_kl 
# array([[ 2,  1,  3],
#        [ 6,  5,  7],
#        [10,  9, 11]])

X[row, col[:, np.newaxis]]
# == X[row[:, np.newaxis], col].T
# array([[ 2,  6, 10],
#        [ 1,  5,  9],
#        [ 3,  7, 11]])

答案 1 :(得分:0)

我来这里是为了寻找这个问题的答案,hpaulj 的评论帮助了我。我将对其进行扩展。

在以下代码段中,

import numpy as np
X = np.arange(12).reshape((3, 4))
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
Y = X[row.reshape(-1, 1), col]

我们传递给X索引正在广播。

下面的代码遵循 numpy 广播规则但使用更多的内存,完成相同的切片:

# Make the row and column indices 'conformable'
R = np.repeat(row.reshape(-1, 1), 3, axis=1)  # repeat row index across columns
C = np.repeat(col.reshape(1, -1), 3, axis=0)  # repeat column index across rows
Y = X[R, C]  # Y[i, j] = X[R[i, j], C[i, j]]