我想用a
和b
生成新的a○b向量(○表示元素明智的乘法)。我的代码如下,但由于for
,性能看起来很糟糕。有没有有效的方法?
a = torch.rand(batch_size, a_len, hid_dim)
b = torch.rand(batch_size, b_len, hid_dim)
# a_elmwise_mul_b = torch.zeros(batch_size, a_len, b_len, hid_dim)
for sample in range(batch_size):
for ai in range(a_len):
for bi in range(b_len):
a_elmwise_mul_b[sample, ai, bi] = torch.mul(a[sample, ai], b[sample, bi])
我更新了我的代码,请参考艾哈迈德!谢谢。
N = 16
hid_dim = 50
a_seq_len = 10
b_seq_len = 20
a = torch.randn(N, a_seq_len, hid_dim)
b = torch.randn(N, b_seq_len, hid_dim)
shape = (N, a_seq_len, b_seq_len, hid_dim)
a_dash = a.unsqueeze(2) # (N, a_len, 1, hid_dim)
b_dash = b.unsqueeze(1) # (N, 1, b_len, hid_dim)
a_dash = a_dash.expand(shape)
b_dash = b_dash.expand(shape)
print(a_dash.size(), b_dash.size())
mul = a_dash * b_dash
print(mul.size())
----------
torch.Size([16, 10, 20, 50]) torch.Size([16, 10, 20, 50])
torch.Size([16, 10, 20, 50])
答案 0 :(得分:2)
从您的问题定义中,您似乎希望将两个张量相乘,比如形状A
和B
的{{1}}和AxE
,并希望得到张量形状BxE
。这意味着你想要将每一个张量AxBxE
与整个张量A
相乘。如果它是正确的,那么我们不称其为元素乘法。
您可以按照以下步骤完成目标。
B
此处,import torch
# batch_size = 16, a_len = 10, b_len = 20, hid_dim = 50
a = torch.rand(16, 10, 50)
b = torch.rand(16, 20, 50)
c = a.unsqueeze(2).expand(*a.size()[:-1], b.size(1), a.size()[-1])
d = b.unsqueeze(1).expand(b.size()[0], a.size(1), *b.size()[1:])
print(c.size(), d.size())
print(c.size(), d.size())
mul = c * d # shape of c, d: 16 x 10 x 20 x 50
print(mul.size()) # 16 x 10 x 20 x 50
张量是您想要的结果。只是为了澄清一下,上面两行用于mul
和c
计算,相当于:
d