我想要一个RNN,输入大小为7,隐藏大小为10,输出大小为2。
因此,对于形状为99x1x7
的输入,我希望输出形状为99x1x2
。
仅对于RNN,我得到:
model = nn.RNN(input_size=7, hidden_size=10, num_layers=1)
output,hn=model(torch.rand(99,1,7))
print(output.shape) #torch.Size([99, 1, 10])
print(hn.shape) #torch.Size([ 1, 1, 10])
所以我认为我仍然必须在其后加上Linear
:
model = nn.Sequential(nn.RNN(input_size=7, hidden_size=10, num_layers=1),
nn.Linear(in_features=10, out_features=2))
model(torch.rand(99,1,7))
Traceback (most recent call last):
File "train_rnn.py", line 80, in <module>
main()
File "train_rnn.py", line 25, in main
model(torch.rand(99,1,7))
File "/home/.../virtual-env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/.../virtual-env/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/home/.../virtual-env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/.../virtual-env/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 92, in forward
return F.linear(input, self.weight, self.bias)
File "/home/.../virtual-env/lib/python3.6/site-packages/torch/nn/functional.py", line 1404, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
我猜这是因为Linear
接收到RNN.forward
产生的元组。但是我应该如何将两者结合起来?
答案 0 :(得分:0)
来自pytorch文档https://pytorch.org/docs/stable/nn.html?highlight=rnn#torch.nn.RNN
输出的形状为seq_len, batch, num_directions * hidden_size
因此,根据您的需要,您可以添加fc层以获得大小为2的输出。
基本上,Sequential
会将每个模型应用到next_one的输出之上,因此,您不得使用Sequential
或创建可用于序列的特殊线性层,以下应可工作:>
class seq_Linear(nn.module):
def __init__(self, linear):
self.linear = linear
# To apply on every hidden state
def forward(self, x):
return torch.stack([self.linear(hs) for hs in x])
# To apply on the last hidden state
def forward(self, x):
return self.linear(x[-1])
,然后在代码中用seq_Linear(nn.Linear)替换nn.Linear。
编辑:如果要创建大小为2的输出序列,最好的方法可能是在第一个RNN的顶部堆叠另一个input_size 10和output_size 2的RNN,它们应该可堆叠在{{1 }},没有任何麻烦。