如何加快Numpy标签贴图中标签张量的创建?

时间:2016-07-19 01:44:17

标签: python numpy

给定尺寸为WXH的标签贴图,其中每个元素可以从{0,..,K-1}获取值,我想输出尺寸为KXW x H的标签张量,其中K&#39th map中的每个元素仅当labelmap中的相应值为K时才为1.目前我的实现使用两个for循环并且非常慢。

p_label = Labelmap with one channel

label = np.zeros((K,p_label.shape[0], p_label.shape[1]))
for i in xrange(p_label.shape[0]):
      for j in xrange(p_label.shape[1]):
           label[p_label[i,j],i,j] = 1

使用广播是否有更好的方法在Numpy中执行此操作?

2 个答案:

答案 0 :(得分:2)

您可以将==运算符用于广播。

例如,

In [19]: W = 5

In [20]: H = 8

In [21]: K = 10

为示例创建p_label

In [22]: p_label = np.random.randint(0, K, size=(W, H))

kvals只是一个包含[0,1,...,K-1]的数组:

In [23]: kvals = np.arange(K)

kvals.reshape(-1, 1, 1)kvals转换为具有形状(K,1,1)的数组。使用==p_label进行比较。广播适用,因此比较的结果具有形状(K,W,H)。它是您想要的值的布尔数组。 .astype(int)将结果转换为整数数组。 (如果布尔数组适合你,你可以删除它。)

In [24]: label = (p_label == kvals.reshape(-1, 1, 1)).astype(int)

这是原始p_label。请注意,例如,值0的位置:

In [25]: p_label
Out[25]: 
array([[3, 3, 2, 6, 2, 2, 9, 3],
       [1, 8, 1, 1, 4, 3, 7, 8],
       [5, 9, 1, 0, 7, 2, 8, 0],
       [1, 3, 5, 4, 6, 0, 9, 5],
       [5, 7, 2, 0, 6, 4, 5, 3]])
label[0]p_label的位置,

0为1。

In [26]: label[0]
Out[26]: 
array([[0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0]])

答案 1 :(得分:1)

Label[p_label, np.arange(p_label.shape[0])[:,None], np.arange(p_label.shape[1])] = 1

3个索引数组相互广播。

==============================

lmap = np.arange(12).reshape(3,4)
lbl = np.zeros((12,3,4),int)
lbl[lmap,np.arange(3)[:,None],np.arange(4)] = 1

In [5]: lbl
Out[5]: 
array([[[1, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       ...
       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 1]]])