更改蒙版np数组中的值

时间:2015-12-07 16:11:41

标签: python arrays numpy

我有一个np数组:

arr = np.array(
[[1,2,3,4,5,6],
[11,12,13,14,15,16],
[21,22,23,24,25,26],
[31,32,33,34,35,36],
[41,42,43,44,45,46]])

以及从arr

中选择“行”的掩码
mask = np.array([False,True,True,True,True])

我试图通过给掩码数组提供相对索引来改变原始数组中的值:

arr[mask1][0,0] = 999

预期产出:

[[  1   2   3   4   5   6]
 [999  12  13  14  15  16]
 [ 21  22  23  24  25  26]
 [ 31  32  33  34  35  36]
 [ 41  42  43  44  45  46]]

然而,问题是arr保持不变。任何解决方法建议?

1 个答案:

答案 0 :(得分:2)

发生了什么:花式与常规索引

这是因为使用布尔数组或索引序列是" fancy"索引。 (" Fancy"是任何不能被表达为切片的东西。)它实际上不是一个"掩盖的数组"这是一个完全用numpy术语(np.ma.masked_array)的单独的东西。

花式索引制作副本。常规索引(即切片)可以查看。视图共享数据,副本不共享。

让我们分解你的表达arr[mask1][0,0] = 999

因为mask1是一个布尔数组,arr[mask1]将返回数据的副本。下一部分将修改该副本,而不是原始数组。换句话说:

# tmp_copy is a copy, not a view, in this case
tmp_copy = arr[mask1]

# tmp_copy is modified, but `arr` is not
tmp_copy[0, 0] = 999 

# Because `tmp_copy` is an intermediate, it will be garbage collected.
# The assignment to 999 effectively disappears
del temp_copy

让我们将其与类似的(在此确切情况下)切片表达式进行对比:arr[1:][0, 0] = 999(这将修改原始arr

# Because we're using a slice, a view will be created instead of a copy
tmp_view = arr[1:]

# Modifying the view will modify the original array as well
tmp_view[0, 0] = 999

# The view can be deleted, but the original has still been modified
del tmp_view

我该如何解决这个问题?

一般来说,你应该避免让自己陷入这种境地。你想要完成的事情通常可以用另一种方式重新演绎。

但是,如果您真的需要,可以通过将奇特的索引转换为您想要修改的特定索引来实现。

例如:

import numpy as np

# Your data...
arr = np.array([[1,2,3,4,5,6],
                [11,12,13,14,15,16],
                [21,22,23,24,25,26],
                [31,32,33,34,35,36],
                [41,42,43,44,45,46]])
mask = np.array([False,True,True,True,True])

# Make a temporary array of the "flat" indices of arr
idx = np.arange(arr.size).reshape(arr.shape)

# Now use this to make your assignment:
arr.flat[idx[mask][0, 0]] = 999

在您的确切情况下,这是过度的(即您可以对arr[1:][0, 0] = 999执行相同操作)。还有很多其他情况可以简化。但是,要获得完全通用的解决方案,我们需要类似于上面示例的内容。

解释变通方法

让我们分解一下这个例子的作用。首先,我们创建一个" flat"与我们的数组形状相同的索引。 (旁注,有关详细信息,请参阅np.unravel_index。)在这种情况下:

In [37]: idx
Out[37]:
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29]])

现在我们可以提取花式索引将提取的索引:

In [38]: idx[mask]
Out[38]:
array([[ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29]])

然后是下一个切片[0,0]

In [39]: idx[mask][0,0]
Out[39]: 6

现在我们有一个" flat"索引回到我们原来的数组。我们可以使用np.unravel_index

将其转换为完整索引
In [40]: np.unravel_index(6, arr.shape)
Out[40]: (1, 0)

...但直接使用arr.flat代替工作更容易:

In [41]: arr.flat[6] = 999

In [42]: arr
Out[42]:
array([[  1,   2,   3,   4,   5,   6],
       [999,  12,  13,  14,  15,  16],
       [ 21,  22,  23,  24,  25,  26],
       [ 31,  32,  33,  34,  35,  36],
       [ 41,  42,  43,  44,  45,  46]])