我正在编写一个使用批矩阵乘法的过程,也许不在通用设置下。我正在考虑以下输入:
# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3
X = torch.tensor([[1,1,1],[1,1,1],
[2,2,2],[2,2,2],[2,2,2],
[3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
[[2,2,2],[2,2,2],[2,2,2]],
[[3,3,3],[3,3,3],[3,3,3]]])
所需的输出是:
out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1) # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]], # obj0 mult with A_0
[[2,2,2]],[[2,2,2]],[[2,2,2]], # obj1 mult with A_1
[[3,3,3]],]) # obj2 mult with A_2
实际上在pytorch中非常方便,只需一行!
在这里,我想改进此过程。请注意,我使用 A [idx] 为每个点复制一个矩阵A_i,因此我可以在此处使用torch.bmm()函数(1个点<-> 1个矩阵)。 Afaik,将需要为 A [idx] 的中间表示分配内存。通常,如果我的数据批处理中包含BN对象,则A [idx] =(B1 + ... + BN)* 3 * 3的大小可能会很大。
因此,我想知道是否可以避免矩阵A_i的复制。
我找到了有关Batch Mat的最常见问题。多仅假定固定的批次大小。 Here与我的问题相同,并提供了张量流解决方案。但是,该解决方案是使用tf.tile()实现的,它也可以复制矩阵。
总而言之,我的问题是关于批矩阵乘法,同时实现:
- dynamic batch size
- input shape: (B1+...+BN) x 3
- index shape: (B1+...+BN)
- memory efficiency
- probably w/out massive replication of matrix
我在这里使用pytorch,但我也接受其他实现。如果可以提高存储效率,我也接受在其他结构中表示输入(例如,相乘矩阵A)。