我有批量数据,想要dot()
数据。 W是可训练的参数。
如何在批量数据和权重之间进行点击?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
这个怎么样?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
答案 0 :(得分:1)
展开W
张量以匹配data
张量的形状。以下应该有效。
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)
修改:您的更新代码是正确的。由于您要将W
转换为Bxhid_dimx1
且数据的形状为Bxdxhid_dim
,因此批量矩阵乘法将产生Bxdx1
,这基本上是{{之间的点积1}}参数和W
(data
)中的所有行向量。