PyTorch中的这一行代码做什么?
normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
通常它应该是第二个术语:
torch.div(input, value, out=None) → Tensor
答案 0 :(得分:2)
您的问题还不清楚,因为您没有提到张量A
的形状和normA
的形状。但我猜是这样:
A
是形状为(batchSize, X, Y)
的张量normA
是A
的所有批处理元素的范数张量,其形状为(batchSize)
。因此,您可以使用以下语句规范张量A
。
A.div(normA.view(batchSize, 1, 1).expand_as(A))
首先将normA.view(batchSize, 1, 1).expand_as(A)
转换为形状(batchSize, X, Y)
的张量,然后将A
除以结果张量。
一个例子(根据我的猜测创建):
batchSize = 8
A = torch.randn(batchSize, 5, 5)
normA = A.norm(dim=-1).norm(dim=-1)
print(normA.size()) # torch.Size([8])
normA = normA.view(batchSize, 1, 1).expand_as(A)
print(normA.size()) # torch.Size([8, 5, 5])
A = A.div(normA)