使用Torch张量屏蔽对角线到特定值

时间:2018-03-27 12:00:04

标签: python numpy torch pytorch

正如问题标题所述,我如何用火炬中的值填充对角线?

a = np.zeros((3, 3), int)
np.fill_diagonal(a, 5)

array([[5, 0, 0],
       [0, 5, 0],
       [0, 0, 5]])

我知道torch.diag()会返回对角线,但我如何使用它作为掩码来分配新值,超出了我的范围,我无法在此处或火炬文档中找到答案。< / p>

2 个答案:

答案 0 :(得分:2)

一种方法:

>>> import torch
>>> n = 3
>>> t = torch.zeros((n,n))
>>> t[torch.eye(n).byte()] = 5
>>> t

 5  0  0
 0  5  0
 0  0  5
[torch.FloatTensor of size 3x3]

答案 1 :(得分:2)

您可以在 PyTorch 中使用 fill_diagonal_ 执行此操作:

>>> a = torch.zeros(3, 3)
>>> a.fill_diagonal_(5)
tensor([[5, 0, 0],
        [0, 5, 0],
        [0, 0, 5]])