Numpy - 如何根据条件替换元素(或匹配模式)

时间:2014-04-15 07:23:11

标签: python arrays numpy

我有一个numpy数组,说:

>>> a=np.array([[0,1,2],[4,3,6],[9,5,7],[8,9,8]])
>>> a
array([[0, 1, 2],
   [4, 3, 6],
   [9, 5, 7],
   [8, 9, 8]])

我想用最小的元素(逐行)替换第二和第三列元素,除非这两个元素中的一个是< 3。 结果数组应为:

array([[0, 1, 2],# nothing changes since 1 and 2 are <3
   [4, 3, 3], #min(3,6)=3 => 6 changed to 3
   [9, 5, 5], #min(5,7)=5 => 7 changed to 5
   [8, 8, 8]]) #min(9,8)=8 => 9 changed to 8

我知道我可以使用剪辑,例如a[:,1:3].clip(2,6,a[:,1:3]),但

1)剪辑将应用于所有元素,包括那些<3。

2)我不知道如何将剪辑的最小值和最大值设置为每行2个相关元素的最小值。

3 个答案:

答案 0 :(得分:3)

只需使用&gt; =运算符首先选择您感兴趣的内容:

b = a[:, 1:3]  # select the columns
matching = numpy.all(b >= 3, axis=1)  # find rows with all elements matching
b = b[matching, :]  # select rows

现在,您可以使用最小值替换内容,例如:

# find row minimum and convert to a column vector
b[:, :] = b.min(1, keepdims=True)

答案 1 :(得分:1)

我们首先定义row_mask,描述<3条件,然后沿轴应用min imum以找到最小值(row_mask中的行)。

将1dim数组(newaxis imums)广播到作业的2-dim目标时需要min部分。

a=np.array([[0,1,2],[4,3,6],[9,5,7],[8,9,8]])
row_mask = (a[:,0]>=3)
a[row_mask, 1:] = a[row_mask, 1:].min(axis=1)[...,np.newaxis]
a
=> 
array([[0, 1, 2],
       [4, 3, 3],
       [9, 5, 5],
       [8, 8, 8]])

答案 2 :(得分:0)

这是一个班轮:

a[np.where(np.sum(a,axis=1)>3),1:3]=np.min(a[np.where(np.sum(a,axis=1)>3),1:3],axis=2).reshape(1,3,1)

以下是细分:

>>> b = np.where(np.sum(a,axis=1)>3) # finds rows where, in a, row sums are > 3
(array([1, 2, 3]),)
>>> c = a[b,1:3] # the part of a that needs to change
array([[[3, 3],
        [5, 5],
        [8, 8]]])
>>> d = np.min(c,axis=2) # the minimum values in each row (cols 1 and 2)
array([[3, 5, 8]])
>>> e = d.reshape(1,3,1) # adjust shape for broadcast to a
array([[[3],
        [5],
        [8]]])

>>> a[np.where(np.sum(a,axis=1)>3),1:3] = e # set the values in a
>>> a

array([[0, 1, 2],
       [4, 3, 3],
       [9, 5, 5],
       [8, 8, 8]])