pytorch最后的nn.Linear()有什么作用,为什么有必要?

时间:2019-09-23 16:01:25

标签: neural-network pytorch lstm recurrent-neural-network

我正在使用一些训练lstm生成序列的代码。训练模型后,将调用lstm()方法:

x = some_input
lstm_output, (h_n, c_n) = lstm(x, hc) 
funcc = nn.Linear(in_features=lstm_num_hidden,
                  output_features=vocab_size,
                  bias=True)
func_output = func(lstm_output)

我看过nn.Linear()的文档,但我仍然不了解此转换正在做什么以及为什么有必要。如果lstm已经过训练,那么它给出的输出应该已经具有预先确定的维度。此输出(lstm_output)将是生成的序列,或者在我的情况下是向量数组。我在这里想念东西吗?

1 个答案:

答案 0 :(得分:4)

在这里,线性层将LSTM产生的隐藏状态表示(lstm_output)转换为大小vocab_size的向量。您的理解也许是错误的。 Linear层应与LSTM一起接受培训。

我猜您正在尝试生成标记(单词)序列,因此应该在Linear层后面进行Softmax操作,以预测词汇表上的概率分布。