如何在Pytorch中批量处理每个矩阵的转置?

时间:2019-08-15 15:29:50

标签: pytorch

说我有4批5x3矩阵。因此,这些张量的尺寸为4x5x3。如何在每个批次中对每个矩阵进行转置。那么将其转换为4x3x5吗?

3 个答案:

答案 0 :(得分:2)

torch.transpose(input, dim0, dim1)将交换两个尺寸。

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],
        [-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
        [-0.9893,  0.7299],
        [ 0.5809,  0.4942]])

如果希望交换更多尺寸,请使用torch.Tensor.permute

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(2, 0, 1).size()
torch.Size([5, 2, 3])

答案 1 :(得分:1)

出于性能考虑,我将在此处放弃一些基准测试。使用OP答案中建议的相同张量。

In[2]: import torch
In[3]: x = torch.randn(2, 3, 5)
In[4]: x.size()
Out[4]: torch.Size([2, 3, 5])
In[5]: %timeit x.permute(1, 0, 2)
1.03 µs ± 41.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In[6]: %timeit torch.transpose(x, 0, 1)
892 ns ± 9.61 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In[7]: torch.transpose(x, 0, 1).equal(x.permute(1, 0, 2))
Out[7]: True

很明显torch.transpose更快,因此建议尽可能使用它。

答案 2 :(得分:0)

作为警告,请执行以下操作:

torch.transpose(x, 1, 2)

如果你有一个大小为 [B, T, D] 的张量。例如,执行 .t() 仅适用于矩阵。这对于在两个矩阵张量之间执行矩阵倍数很有用:

att = torch.transopose(x, 1, 2) @ x

或者如果您希望可变长度序列彼此面对并消除该维度。