我有两个火炬张量a和b。张量a的形状为[batch_size,emb_size],张量b的形状为[num_of_words,emb_size]。我想在这两个张量上执行逐元素乘积,而不是点积。
我注意到“ *”可以执行逐元素乘积运算,但不适合我的情况。
例如,batch_size = 3,emb_size = 2,num_of_words = 5。
a = torch.rand((3,2))
b = torch.rand((5,2))
我想得到类似的东西:
torch.cat([a[0]*b, a[1]*b, a[2]*b]).view(3, 5, 2)
但是我想以一种高效而优雅的方式做到这一点。
答案 0 :(得分:1)
您可以使用
public String toString()
{
return getClass().getName()+"@"+Integer.toHexString(hashCode());
}
PyTorch支持broadcasting semantics,但您需要确保单例尺寸在正确的位置。