我有一批形状为(N, C, H, W)
的图像,其中N代表图像数量,C代表通道数量,H,W代表高度和宽度。
每个图像都有2个通道,其中某些像素的值为[-1 , -1]
。
如何在不使用for循环的情况下查找批处理中这些像素的位置,因为它非常慢。
答案 0 :(得分:1)
使用numpy.where
:
# creating test data
test = np.zeros((5,2,3,3))
test[3,:,2,1] = [-1.,-1.]
value = -np.ones((1.,2.,1.,1.)) # this is the value you are looking for
np.where(test == value)
# this returns: (array([3, 3], dtype=int64),
# array([0, 1], dtype=int64),
# array([2, 2], dtype=int64),
# array([1, 1], dtype=int64))
编辑:
要获取相应的蒙版,只需不使用where
:
test == value
答案 1 :(得分:0)
您可以使用numpy.where
。一个简单的例子:
x = np.random.randn(4,2,10,10)
x[0,1,2,3] = 1
x[0,1,4,5] = 1
np.where(x==1)
(array([0,0],dtype = int64),array([1,1],dtype = int64),array([2, 4],dtype = int64),array([3,5],dtype = int64))