在Pytorch中实现自定义激活功能(例如Swish)时遇到问题。我应该如何在Pytorch中实现和使用自定义激活功能?
答案 0 :(得分:0)
您可以编写如下的自定义激活功能(例如加权Tanh)。
class weightedTanh(nn.Module):
def __init__(self, weights = 1):
super().__init__()
self.weights = weights
def forward(self, input):
ex = torch.exp(2*self.weights*input)
return (ex-1)/(ex+1)
如果使用兼容autograd
的操作,请不要担心反向传播。
答案 1 :(得分:0)
根据您要寻找的内容,有四种可能性。您将需要问自己两个问题:
问题1)您的激活功能是否具有可学习的参数?
如果是,则没有选择将激活函数创建为nn.Module
类,因为您需要存储这些权重。
如果否,则可以根据自己的方便随意创建普通函数或类。
问题2)您的激活功能可以表示为现有PyTorch功能的组合吗?
如果是,则可以简单地将其编写为现有PyTorch函数的组合,而无需创建定义梯度的backward
函数。
如果否,则需要手动编写渐变。
示例1:旋转功能
swish函数f(x) = x * sigmoid(x)
没有任何学习的权重,可以完全用现有的PyTorch函数编写,因此您可以简单地将其定义为一个函数:
def swish(x):
return x * torch.sigmoid(x)
,然后像使用torch.relu
或任何其他激活功能一样简单地使用它。
示例2:以学到的坡度挥舞
在这种情况下,您有一个学习的参数,即斜率,因此您需要对其进行分类。
class LearnedSwish(nn.Module):
def __init__(self, slope = 1):
super().__init__()
self.slope = slope * torch.nn.Parameter(torch.ones(1))
def forward(self, x):
return self.slope * x * torch.sigmoid(x)
示例3:向后
如果需要创建自己的渐变函数,可以查看以下示例:Pytorch: define custom function