Numpy使用带有nan的索引列表替换每行的特定列索引

时间:2017-02-10 17:44:02

标签: python numpy

我正在尝试以下方法:

a = np.array([[1,2,3], [4,5,6], [7,8,9]])

print a
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

a[np.arange(len(a)), [1,0,2]] = 20 #--Code1

print a
array([[ 1, 20,  3],
       [20,  5,  6],
       [ 7,  8, 20]])

但是,如果我的索引中包含nan

a[np.arange(len(a)), [1,np.nan,2]] = 20  #--Code2

错误了。

我想要做的是,如果索引中存在nan,请不要更改任何内容。

即。我想在上面实现Code2,以便我可以获得以下内容:

    array([[ 1, 20,  3],
           [4,  5,  6],
           [ 7,  8, 20]])

1 个答案:

答案 0 :(得分:1)

使用masking -

m = ~np.isnan(idx) # Mask of non-NaNs
row = np.arange(a.shape[0])[m]
col = idx[m].astype(int)
a[row, col] = 20

其中,idx是索引数组。

示例运行 -

In [161]: a = np.array([[1,2,3], [4,5,6], [7,8,9]])

In [162]: idx = np.array([1,np.nan,2])

In [163]: m = ~np.isnan(idx) # Mask of non-NaNs
     ...: row = np.arange(a.shape[0])[m]
     ...: col = idx[m].astype(int)
     ...: a[row, col] = 20
     ...: 

In [164]: a
Out[164]: 
array([[ 1, 20,  3],
       [ 4,  5,  6],
       [ 7,  8, 20]])