我正在寻找一种简单的方法来使用pytorch库中存在的激活函数,但是要使用某种参数。例如:
Tanh(x / 10)
我找到解决方案的唯一方法是完全从头开始实现自定义功能。有没有更好/更优雅的方式来做到这一点?
编辑:
我正在寻找将函数Tanh(x / 10)而不是普通Tanh(x)附加到模型中的方法。这是相关的代码块:
Enumerable
答案 0 :(得分:1)
您可以将其内联到自定义层中,而不是将其定义为特定功能。
例如,您的解决方案可能如下所示:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(4, 10)
self.fc2 = nn.Linear(10, 3)
self.fc3 = nn.Softmax()
def forward(self, x):
return self.fc3(self.fc2(torch.tanh(self.fc1(x)/10)))
其中torch.tanh(output/10)
内联到模块的前进功能中。
答案 1 :(得分:0)
您可以使用乘法参数创建图层:
import torch
import torch.nn as nn
class CustomTanh(nn.Module):
#the init method takes the parameter:
def __init__(self, multiplier):
self.multiplier = multiplier
#the forward calls it:
def forward(self, x):
x = self.multiplier * x
return torch.tanh(x)
使用CustomTanh(1/10)
而不是nn.Tanh()
将其添加到模型中。