如何在PyTorch中使用numpy函数作为损失函数并避免在运行时出错?

时间:2019-01-13 02:34:59

标签: python numpy pytorch

对于我的任务,我不需要计算梯度。我只是在损失评估中将numpy函数(nn.L1Loss替换为corrcoef,但出现以下错误:

RuntimeError: Can’t call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

我无法弄清楚应该如何分离图表(我尝试过torch.Tensor.detach(np.corrcoef(x, y)),但是仍然遇到相同的错误。我最终使用torch.no_grad包装了所有内容,如下所示:

with torch.no_grad():
    predFeats = self.forward(x)
    targetFeats = self.forward(target)
    loss = torch.from_numpy(np.corrcoef(predFeats.cpu().numpy().astype(np.float32), targetFeats.cpu().numpy().astype(np.float32))[1][1])

但是这次我得到以下错误:

TypeError: expected np.ndarray (got numpy.float64)

我想知道,我在做什么错了?

1 个答案:

答案 0 :(得分:1)

TL; DR

with torch.no_grad():
    predFeats = self(x)
    targetFeats = self(target)
    loss = torch.tensor(np.corrcoef(predFeats.cpu().numpy(),
                                    targetFeats.cpu().numpy())[1][1]).float()

通过从计算图上分离张量(RuntimeErrorpredFeats),可以避免第一个targetFeats。 即获取没有梯度和梯度函数(grad_fn)的张量数据的副本。

所以,而不是

torch.Tensor.detach(np.corrcoef(x.numpy(), y.numpy())) # Detaches a newly created tensor!
# x and y still may have gradients. Hence the first error.

什么都不做,做

# Detaches x and y properly
torch.Tensor(np.corrcoef(x.detach().numpy(), y.detach().numpy()))

但是让我们不要为所有分队而烦恼。

就像您正确修复的那样,让我们​​禁用渐变。

torch.no_grad()

现在,计算特征。

predFeats = self(x) # No need for the explicit .forward() call
targetFeats = self(target)

我发现打破最后一行很有帮助。

loss = np.corrcoef(predFeats.numpy(), targetFeats.numpy()) # We don't need to detach

# Notice that we don't need to cast the arguments to fp32
# since the `corrcoef` casts them to fp64 anyway.

print(loss.shape, loss.dtype) # A 2-dimensional fp64 matrix

loss = loss[1][1]
print(type(loss)) # Output: numpy.float64
# Loss now just a simple fp64 number

这就是问题所在!

因为,当我们这样做

loss = torch.from_numpy(loss)

我们传入一个数字(numpy.float64),但它期望一个numpy张量(np.ndarray)。

如果您使用的是PyTorch 0.4或更高版本,则对标量有内置支持。

只需将from_numpy()方法替换为通用的tensor()创建方法。

loss = torch.tensor(loss)

P.S。您可能还需要查看在rowvar=False中设置corrcoef,因为PyTorch张量中的行通常表示观测值。