python:np.where()和广播

时间:2018-03-15 06:42:48

标签: python numpy numpy-broadcasting

有人可以帮我理解np.where()函数下面的广播是如何工作的吗?

x = np.arange(9.).reshape(3, 3)
np.where(x < 5, x, -1)      # Note: broadcasting.
array([[ 0.,  1.,  2.],
   [ 3.,  4., -1.],
   [-1., -1., -1.]])

1 个答案:

答案 0 :(得分:0)

让我们看看各个部分

x = np.arange(9).reshape(3, 3)

>>> x
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

请注意x < 5生成一组布尔值:

>>> x < 5
array([[ True,  True,  True],
       [ True,  True, False],
       [False, False, False]]

将其插入np.where

>>> np.where(x < 5, x, -1)
array([[ 0,  1,  2],
       [ 3,  4, -1],
       [-1, -1, -1]])

请注意,-1已广播以匹配x < 5的尺寸:

array([[-1, -1, -1],
       [-1, -1, -1],
       [-1, -1, -1]])

由于x已经具有正确的尺寸,因此不需要任何广播。