pytorch /梯度计算/就地操作

时间:2020-07-02 12:36:31

标签: pytorch

我收到上述错误,但实际上找不到in_place操作。这是我要运行的代码。为了澄清:列表“ u”包含神经网络,但是我希望最后一个总是返回1s.t。断言会解决。

def generate_stopping_time_factors_from_path(self, x_input):
    local_N = x_input.shape[1]
    U = torch.empty(local_N)
    U[0] = torch.zeros(1)
    x = torch.empty(self.d, local_N)
    # x = torch.from_numpy(x_input) doesn't work for some reason

    h = torch.empty(self.N + 1)
    h[self.N] = 1

    for n in range(local_N):
        sum = torch.sum(U[0:n])
        x[:, n] = torch.tensor(x_input[:, n], requires_grad=True)
        if n < self.N:
            # all entries in "u" are nets
            h[n] = self.u[n](x[:, n])
        U[n] = h[n] * (torch.ones(1) - sum)

    assert torch.sum(U).item() == 1

    return U

但是,当我进一步测试它时,我发现第二个代码下面的2个代码不能以相同的错误工作,而第一个代码运行得很好(显然,它可以计算废话,但是可以计算)

def generate_stopping_time_factors_from_path(self, x_input):
    return self.u[0](x[:, 0])


def generate_stopping_time_factors_from_path(self, x_input):
    h = torch.empty(self.N + 1)
    h[0] = self.u[0](x[:, 0])
    return h[0]

在空张量中插入值是否真的算作就位操作?如果是这样,我如何重新编码我的代码来解决这个问题?我需要将网络的输出保存在本地,因为下一个网络的输出会将其用作一个因数。

0 个答案:

没有答案