具有动态批次大小的批次矩阵乘法

时间:2019-11-22 12:00:31

标签: machine-learning pytorch matrix-multiplication transformation-matrix

我正在编写一个使用批矩阵乘法的过程,也许不在通用设置下。我正在考虑以下输入:

# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3 
X = torch.tensor([[1,1,1],[1,1,1],
                  [2,2,2],[2,2,2],[2,2,2],
                  [3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
                  [[2,2,2],[2,2,2],[2,2,2]],
                  [[3,3,3],[3,3,3],[3,3,3]]])

所需的输出是:

out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1)                # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]],            # obj0 mult with A_0
                      [[2,2,2]],[[2,2,2]],[[2,2,2]],  # obj1 mult with A_1
                      [[3,3,3]],])                    # obj2 mult with A_2

实际上在pytorch中非常方便,只需一行!

在这里,我想改进此过程。请注意,我使用 A [idx] 为每个点复制一个矩阵A_i,因此我可以在此处使用torch.bmm()函数(1个点<-> 1个矩阵)。 Afaik,将需要为 A [idx] 的中间表示分配内存。通常,如果我的数据批处理中包含BN对象,则A [idx] =(B1 + ... + BN)* 3 * 3的大小可能会很大。

因此,我想知道是否可以避免矩阵A_i的复制。

我找到了有关Batch Mat的最常见问题。多仅假定固定的批次大小。 Here与我的问题相同,并提供了张量流解决方案。但是,该解决方案是使用tf.tile()实现的,它也可以复制矩阵。


总而言之,我的问题是关于批矩阵乘法,同时实现:

- dynamic batch size
    - input shape: (B1+...+BN) x 3
    - index shape: (B1+...+BN)
- memory efficiency
    - probably w/out massive replication of matrix

我在这里使用pytorch,但我也接受其他实现。如果可以提高存储效率,我也接受在其他结构中表示输入(例如,相乘矩阵A)。

0 个答案:

没有答案