如果所有值都低于test,则替换3D数组行 - 使用索引编制一条numpy行

时间:2017-02-10 01:43:09

标签: python numpy

有没有一种简单的方法可以做到这一点?原始数组的维度为N x M x 3,我想制作一个N x M掩码,如图所示。我认为有一种简单的方法可以用Numpy写这个,但我有脑冻结。

虽然可能有其他方法可以实现这一点,但我试图通过这种方式来保持我使用Numpy索引的技能。

此操作失败,因为img2[test]丢失了原始形状。我是否必须以某种方式.reshape(),或者是否有更简单的方法我不需要明确重述轴的尺寸?

import numpy as np
import matplotlib.pyplot as plt

img  = np.random.random(100*100*3).reshape(100, 100, -1)
img2 = img.copy()

target = np.array([0.5, 0.3, 0.7])

test = np.abs(img - target) < 0.5

img2[test] = np.zeros(3)  # probem, img2[test] looses its shape

plt.figure()
plt.imshow(np.hstack((img, img2)))
plt.show()

File "/Users/yournamehere/Documents/fishing/Eu;er diagram/stackexchange question v00.py", line 12, in <module>
    img2[test] = target
ValueError: NumPy boolean array indexing assignment cannot assign 3 input values to the 11918 output values where the mask is true

2 个答案:

答案 0 :(得分:1)

这是一个展示非常有用的keepdims kwarg

的单线程
np.where(np.logical_and.reduce(test, axis=-1, keepdims=True), [0, 0, 0], img)

答案 1 :(得分:0)

大脑没有冻结。忘记使用.all()将测试合并到每个N x M位置的单个布尔值。一旦纠正了,Numpy就会很好地做到这一点。

enter image description here

import numpy as np
import matplotlib.pyplot as plt

img  = np.random.random(100*100*3).reshape(100, 100, -1)
img2 = img.copy()

target = np.array([0.5, 0.3, 0.7])

test = (np.abs(img - target) < 0.5).all(axis=-1)

img2[test] = np.zeros(3)  # probem, img2[test] looses its shape

plt.figure()
plt.imshow(np.hstack((img, img2)))
plt.show()