我正在阅读Pytorch中LSTM的实现。 代码如下:
lstm = nn.LSTM(3, 3) # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)] # make a sequence of length 5
# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
torch.randn(1, 1, 3))
for i in inputs:
# Step through the sequence one element at a time.
# after each step, hidden contains the hidden state.
out, hidden = lstm(i.view(1, 1, -1), hidden)
我不明白为什么隐藏状态是由两个张量而不是一个张量的元组定义的?由于隐藏层只是前馈神经网络的一个向量层。
答案 0 :(得分:2)
除了隐藏状态,LSTM还具有单元状态C。因此,我认为通过了一个元组。参见https://pytorch.org/docs/stable/nn.html#lstmcell。
如果您不传递C,它将被视为全零。
请注意,对于LSTM,GRU或RNN没有C的情况就是如此。