在多个维度上获得对角元素

时间:2018-02-09 16:08:43

标签: python numpy pytorch

我希望将尺寸(n×n×m×m)的张量T变换为尺寸(n×m×m)的张量U,同时仅在(N×N)块(即Uik1 = Tiik1)上检索T的对角线元素。 torch.diag()仅适用于二维张量,我真的没有看到如何在没有循环元素索引的情况下做到这一点(我想避免考虑到我认为它是低效的计算) 。很明显,我想要对以下代码进行矢量化:

U = torch.zeros(n, m, m)
for i in range(n):
    for k in range(m):
        for l in range(m):
            U[i][k][l] = T[i][i][k][l]

我对pytorch完全陌生,我尝试了很多功能组合,但没有一个给我一个令人满意的结果。有人有想法吗?

2 个答案:

答案 0 :(得分:0)

您可以使用np.meshgrid

生成索引
i, k, l = np.meshgrid(range(n), range(m), range(m))
U[i, k, l] = T[i, i, k, l]

为了完整性,我做了:

n = 3
m = 5

T = torch.arange(n * n * m * m).view(n, n, m, m)
U = torch.zeros(n, m, m)
U_ = torch.zeros(n, m, m)

i, k, l = np.meshgrid(range(n), range(m), range(m))

U_[i, k, l] = T[i, i, k, l]

for i in range(n):
    for k in range(m):
        for l in range(m):
            U[i][k][l] = T[i][i][k][l]

U = U.view(-1)
U_ = U_.view(-1)

print ((U == U_).all())

输出为True,所以我认为它是正确的。

答案 1 :(得分:0)

应用于二维矩阵时,torch.diag()torch.diagonal() 的别名。

diagonal 本身允许您指定对角线取自任意秩张量的哪两个维度,默认情况下为 0 和 1:

U = T.diagonal()