Pytorch-对角矩阵块设置有效吗?

时间:2019-02-24 20:42:22

标签: pytorch

我的张量A的大小为[N x 3 x 3],矩阵B的大小为[N * 3 x N * 3]

我想要复制A-> B的内容,以便基本填充对角元素,并且我想有效地做到这一点:

应该看起来像这样填充B:

enter image description here

因此,每个[i,3,3]都沿着线的对角线填充到B中的每个[3x3]部分中。

我该怎么做?对于实时应用,应尽可能有效。我可以编写一个CUDA内核来做到这一点,但是我更喜欢用一些特殊的Pytorch函数来完成它

3 个答案:

答案 0 :(得分:0)

我在Gist上进行了矢量化实现:block_diag.py

有关最新版本,请检查我的numpytorch.py库中pylabyk中的block_diag()。

答案 1 :(得分:0)

使用torch.block_diag()

# Setup
A = torch.ones(3,3,3, dtype=int)

# Unpack blocks and apply
B = torch.block_diag(*A)
>>> B
tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1]])

答案 2 :(得分:-1)

这是一个简单的示例,不确定真正的大张量的性能:

代码:

import torch

# Create some tensors
N = 3
A = torch.ones(N, 3, 3)
A[1] *= 2
A[2] *= 3
B = torch.zeros(N*3, N*3)


def diagonalizer(A, B):
    N = A.shape[0]
    i_min = 0
    j_min = 0
    i_max = 3
    j_max = 3

    for t in range(N):
        B[i_min:i_max, j_min:j_max] = A[t]  # NOTE! this is inplace operation

        # do the step:
        i_min += 3
        j_min += 3
        i_max += 3
        j_max += 3


print('before:\n', B, sep='')

diagonalizer(A, B)

print('after:\n', B, sep='')

输出:

before:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])
after:
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.],
        [0., 0., 0., 0., 0., 0., 3., 3., 3.]])