如何在PyTorch中执行基于元素的乘积?

时间:2019-08-16 04:22:23

标签: python pytorch

我有两个火炬张量a和b。张量a的形状为[batch_size,emb_size],张量b的形状为[num_of_words,emb_size]。我想在这两个张量上执行逐元素乘积,而不是点积。

我注意到“ *”可以执行逐元素乘积运算,但不适合我的情况。

例如,batch_size = 3,emb_size = 2,num_of_words = 5。

a = torch.rand((3,2))
b = torch.rand((5,2))

我想得到类似的东西:

torch.cat([a[0]*b, a[1]*b, a[2]*b]).view(3, 5, 2)

但是我想以一种高效而优雅的方式做到这一点。

1 个答案:

答案 0 :(得分:1)

您可以使用

public String toString()
{
      return getClass().getName()+"@"+Integer.toHexString(hashCode());
}

PyTorch支持broadcasting semantics,但您需要确保单例尺寸在正确的位置。