pytorch:如何将3d张量与2d张量相乘

时间:2020-06-24 16:05:57

标签: pytorch

我得到了一个3D张量three和一个2D张量two,需要将它们相乘。例如,尺寸为:

three.shape = 4x100x700
two.shape = 4x100

输出形状应为:

output.shape = 4x100x700

因此,基本上,在output[a,b]中应该有700个标量,其计算方法是将three[a,b]中的所有700个标量与two[a,b]中的单个标量相乘。

1 个答案:

答案 0 :(得分:1)

您可以简单地向two添加一个额外的维度:

output = three * two.unsqueeze(-1)

还有其他语法,例如:

output = three * two[..., None]