如何在PyTorch中替换infs以避免nan渐变

时间:2019-06-19 12:25:47

标签: python pytorch

我需要计算log(1 + exp(x)),然后对其使用自动微分。但是对于x而言,由于幂运算,它会输出inf

>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> x.exp().log1p()
tensor([0.6931, 1.3133,    inf], grad_fn=<Log1PBackward>)

由于log(1 + exp(x)) ≈ x代表大x,因此我认为可以使用infsx替换为torch.where。但是当这样做时,对于太大的值,我仍然得到nan。您知道为什么会发生这种情况以及是否有其他方法可以使它起作用吗?

>>> exp = x.exp()
>>> y = x.where(torch.isinf(exp), exp.log1p())  # Replace infs with x
>>> y  # No infs
tensor([  0.6931,   1.3133, 100.0000], grad_fn=<SWhereBackward>)
>>> y.sum().backward()  # Automatic differentiation
>>> x.grad  # Why is there a nan and how can I get rid of it?
tensor([0.5000, 0.7311,    nan])

3 个答案:

答案 0 :(得分:1)

  

但是对于x太大,由于求幂,它会输出inf

这就是为什么x永远不要太大。理想情况下,它应该在[-1,1]范围内。 如果不是这种情况,则应标准化输入。

答案 1 :(得分:0)

我发现一种解决方法是手动实现一个Log1PlusExp函数及其后向功能。但这并不能解释问题中torch.where的不良行为。

>>> class Log1PlusExp(torch.autograd.Function):
...     """Implementation of x ↦ log(1 + exp(x))."""
...     @staticmethod
...     def forward(ctx, x):
...         exp = x.exp()
...         ctx.save_for_backward(x)
...         return x.where(torch.isinf(exp), exp.log1p())
...     @staticmethod
...     def backward(ctx, grad_output):
...         x, = ctx.saved_tensors
...         return grad_output / (1 + (-x).exp())
... 
>>> log_1_plus_exp = Log1PlusExp.apply
>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> log_1_plus_exp(x)  # No infs
tensor([  0.6931,   1.3133, 100.0000], grad_fn=<Log1PlusExpBackward>)
>>> log_1_plus_exp(x).sum().backward()
>>> x.grad  # And no nans!
tensor([0.5000, 0.7311, 1.0000])

答案 2 :(得分:-1)

如果x> = 20,则函数输出约为x。 使用PyTorch方法torch.softplus。 它可以解决问题。