如何在pytorch中按行乘以标量?

时间:2018-12-31 13:11:49

标签: pytorch tensor scalar

当我有形状为m的张量[12, 10]和形状为s的标量向量[12]时,如何将m的每一行相乘与s中对应的标量?

3 个答案:

答案 0 :(得分:4)

您需要添加相应的单例尺寸:

m * s[:, None]

s[:, None]的大小为(12, 1),当将(12, 10)张量乘以(12, 1)张量时pytoch知道沿第二个单例broadcast s确定尺寸并正确执行“基于元素的”产品。

答案 1 :(得分:1)

您可以将向量广播到更高维的张量 like so

def row_mult(t,vector):
    extra_dims = (1,)*(t.dim()-1)
    return t * vector.view(-1, *extra_dims)

答案 2 :(得分:0)

如果您事先知道尺寸数并且可以将正确数量的None进行硬编码,Shai的答案就可以使用。可以扩展到需要额外的尺寸:

mask = (torch.rand(12) > 0.5).int()  
data = (torch.rand(12, 2, 3, 4))
result = data * mask[:,None,None,None]

result.shape                  # torch.Size([12, 2, 3, 4])
mask[:,None,None,None].shape  # torch.Size([12, 1, 1, 1])

如果要处理尺寸可变或未知的数据,则可能需要手动将mask扩展到正确的形状

mask = (torch.rand(12) > 0.5).int()
while mask.dim() < data.dim(): mask.unsqueeze_(1)
result = data * mask

result.shape  # torch.Size([12, 2, 3, 4])
mask.shape    # torch.Size([12, 1, 1, 1])

这是一个丑陋的解决方案,但是确实有效。对于可变数量的尺寸,可能有一种更优雅的方法来正确地重塑mask张量内联