pytorch自定义激活功能?

时间:2019-04-19 17:00:13

标签: pytorch

在Pytorch中实现自定义激活功能(例如Swish)时遇到问题。我应该如何在Pytorch中实现和使用自定义激活功能?

2 个答案:

答案 0 :(得分:0)

您可以编写如下的自定义激活功能(例如加权Tanh)。

class weightedTanh(nn.Module):
    def __init__(self, weights = 1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        ex = torch.exp(2*self.weights*input)
        return (ex-1)/(ex+1)

如果使用兼容autograd的操作,请不要担心反向传播。

答案 1 :(得分:0)

根据您要寻找的内容,有四种可能性。您将需要问自己两个问题:

问题1)您的激活功能是否具有可学习的参数?

如果,则没有选择将激活函数创建为nn.Module类,因为您需要存储这些权重。

如果,则可以根据自己的方便随意创建普通函数或类。

问题2)您的激活功能可以表示为现有PyTorch功能的组合吗?

如果,则可以简单地将其编写为现有PyTorch函数的组合,而无需创建定义梯度的backward函数。

如果,则需要手动编写渐变。

示例1:旋转功能

swish函数f(x) = x * sigmoid(x)没有任何学习的权重,可以完全用现有的PyTorch函数编写,因此您可以简单地将其定义为一个函数:

def swish(x):
    return x * torch.sigmoid(x)

,然后像使用torch.relu或任何其他激活功能一样简单地使用它。

示例2:以学到的坡度挥舞

在这种情况下,您有一个学习的参数,即斜率,因此您需要对其进行分类。

class LearnedSwish(nn.Module):
    def __init__(self, slope = 1):
        super().__init__()
        self.slope = slope * torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.slope * x * torch.sigmoid(x)

示例3:向后

如果需要创建自己的渐变函数,可以查看以下示例:Pytorch: define custom function