如何对指定为矩阵行的特定索引集的平均值进行矢量计算?

时间:2019-01-16 15:23:09

标签: python numpy vectorization pytorch

我在pytorch中向量化某些代码时遇到问题。 numpy解决方案也有帮助,但是pytorch解决方案会更好。 我将交替使用arrayTensor

我面临的问题是:

给出大小为(n,x)的2D浮点数组X和大小为(n,n)的布尔2D数组A,计算X中各行的平均值由A中的行索引。 问题在于A中的行包含可变数量的True索引。

示例(numpy):

import numpy as np
A = np.array([[0, 1, 0, 0, 0, 0],
              [1, 0, 1, 0, 0, 0],
              [0, 1, 0, 0, 0, 0],
              [0, 0, 0, 0, 1, 0],
              [0, 0, 0, 1, 0, 0],
              [0, 1, 1, 1, 0, 0]])
X = np.arange(6 * 3, dtype=np.float32).reshape(6, 3)

# Compute the mean in numpy with a for loop
means_np = np.array([X[A.astype(np.bool)[i]].mean(axis=0) for i in np.arange(len(A)])

因此,此示例有效,但此公式存在3个问题:

  1. 对于较大的AX,for循环很慢。我需要遍历几万个索引。

  2. A[i]可能不包含任何True索引。结果为np.mean(np.array([])),即NaN。我希望它改为0。

  3. 在pytorch中以这种方式执行此操作会导致反向传播通过此函数的反向传递过程中出现SIGFPE(浮点错误)。原因是什么也没选择。

我现在使用的解决方法是(另请参见下面的代码):

  • A的对角元素设置为True,以便始终至少有一个元素可供选择
  • 所有选定元素的总和,从该总和中减去X中的值(保证对角线开头应为False),然后除以True个元素的数量-每行至少1个钳位1个。

这有效,在pytorch中是可区分的,不会产生NaN,但是我仍然需要在所有索引上循环。 如何摆脱这个循环?

这是我当前的pytorch代码:

 import torch
 A = torch.from_numpy(A).bytes()
 X = torch.from_numpy(X)
 A[np.diag_indices(len(A)] = 1  # Set the diagonal to 1
 means = [(X[A[i]].sum(dim=0) - X[i]) / torch.clamp(A[i].sum() - 1, min=1.)  # Compute the mean safely
          for i in range(len(A))]  # Get rid of the loop somehow
 means = torch.stack(means)

我不介意您的版本看起来是否完全不同,只要它可以区分并产生相同的结果即可。

1 个答案:

答案 0 :(得分:1)

我们可以利用matrix-multiplication-

c = A.sum(1,keepdims=True)
means_np = np.where(c==0,0,A.dot(X)/c)

我们可以通过将A dtype转换为float32 dtype来进一步优化它,如果还不是这样,并且在那里精度还可以的话,如下所示-

In [57]: np.random.seed(0)

In [58]: A = np.random.randint(0,2,(1000,1000))

In [59]: X = np.random.rand(1000,1000).astype(np.float32)

In [60]: %timeit A.dot(X)
10 loops, best of 3: 27 ms per loop

In [61]: %timeit A.astype(np.float32).dot(X)
100 loops, best of 3: 10.2 ms per loop

In [62]: np.allclose(A.dot(X), A.astype(np.float32).dot(X))
Out[62]: True

因此,使用A.astype(np.float32).dot(X)替换A.dot(X)

或者,要解决行总和为zero,并且要求我们使用np.where的情况,我们可以将任何非零值(例如1)分配给c,然后将其除以-

c = A.sum(1,keepdims=True)
c[c==0] = 1
means_np = A.dot(X)/c

这还可以避免在那些零行总和情况下从np.where得到警告的警告。