更改Torch.Tensor特定部分的值时出现RuntimeError

时间:2019-02-27 03:33:58

标签: python pytorch tensor

说我有一个3维的张量x,它以零初始化:

x = torch.zeros((2, 2, 2))

和另外3个三维张量y

y = torch.ones((2, 1, 2))

我正在尝试像这样更改x[0]x[1]第一行的值

x[:, 0, :] = y

但我收到此错误:

RuntimeError: expand(torch.FloatTensor{[2, 1, 2]}, size=[2, 2]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

就像张量y受到某种程度的挤压一样。有办法解决吗?

2 个答案:

答案 0 :(得分:1)

这是您想要的吗?

x = torch.arange(0, 8).reshape((2,2,2))
y = torch.ones((2,2))
x2 = x.permute(1,0,2)
x2[0] = y
x_target = x2.permute(1,0,2)

x的第一行的值被y更改。

答案 1 :(得分:0)

我找到了一种直接的方法:

x[:, 0, :] = y[:, 0, :]