在PyTorch中用向量替换对角元素

时间:2018-03-22 12:49:26

标签: python pytorch diagonal

我一直在寻找与PyTorch相当的以下内容,但我找不到任何东西。

{{1}}

我想没有办法用Pytorch以这种优雅的方式替换对角元素。

3 个答案:

答案 0 :(得分:1)

我认为现在没有实现这样的功能。但是,您可以使用mask实现相同的功能,如下所示。

# Assuming v to be the vector and a be the tensor whose diagonal is to be replaced
mask = torch.diag(torch.ones_like(v))
out = mask*torch.diag(v) + (1. - mask)*a

因此,您的实现将类似于

L_1 = torch.tril(torch.randn((D, D)))
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
L_1 = mask*torch.diag(v) + (1. - mask)*L_1

不像numpy那样优雅,但也不是太糟糕。

答案 1 :(得分:1)

有一种更简单的方法

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

实际上,我们必须自己生成对角线索引。

用法示例:

dest_matrix = torch.randint(10, (3, 3))
source_vector = torch.randint(100, 200, (len(dest_matrix), ))
print('dest_matrix:\n', dest_matrix)
print('source_vector:\n', source_vector)

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

print('result:\n', dest_matrix)

# dest_matrix:
#  tensor([[3, 2, 5],
#         [0, 3, 5],
#         [3, 1, 1]])
# source_vector:
#  tensor([182, 169, 147])
# result:
#  tensor([[182,   2,   5],
#         [  0, 169,   5],
#         [  3,   1, 147]])

如果dest_matrix不为正方形,则必须取min(dest_matrix.size())中的len(dest_matrix)而不是range()

不如numpy优雅,但这不需要存储新的索引矩阵。

是的,这会保留渐变

答案 2 :(得分:0)

您可以使用 diagonal() 提取对角线元素,然后使用 copy_() 分配转换后的值:

new_diags = L_1.diagonal().exp()
L_1.diagonal().copy_(new_diags)