I am using Python 3 and PyTorch to build a basic LSTM for a text-classification problem. However, the predictions I get are identical for every input. I inspected the hidden states for each batch and found that all the hidden states within a batch are identical. I would expect different inputs to produce different final hidden states, right? So how can I correct my code? (A sanity check I tried is included after the code.)
import torch
import torch.nn as nn

# embedding_dim, hidden_size, batch_size, vocab_size and num_classes
# are module-level hyperparameters defined elsewhere in my script.

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        # default nn.LSTM expects input shaped (seq_len, batch, input_size)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size
        )
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.requires_grad = True
        self.fc = nn.Linear(self.hidden_size, num_classes)

    def init_hidden(self, batch_size_=None):
        if batch_size_ is None:
            batch_size_ = self.batch_size
        # shape: (num_layers * num_directions, batch, hidden_size);
        # plain tensors, since autograd.Variable is deprecated (PyTorch >= 0.4)
        h0 = torch.zeros(1, batch_size_, self.hidden_size)
        c0 = torch.zeros(1, batch_size_, self.hidden_size)
        return (h0, c0)

    def forward(self, X):
        X = self.word_embeddings(X)  # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        X = X.permute(1, 0, 2)       # -> (seq_len, batch, embedding_dim)
        # print(X)  # debug: inspect the embedded batch
        self.hidden = self.init_hidden(X.size(1))
        output, (final_hidden_state, final_cell_state) = self.lstm(X, self.hidden)
        # final_hidden_state[-1] is the last layer's hidden state: (batch, hidden_size)
        return self.fc(final_hidden_state[-1])
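For reference, this is the minimal sanity check I would expect to pass: two batches of different token ids fed through a freshly initialized model should already produce different final hidden states, and therefore different outputs. The hyperparameter values below are placeholders for illustration, not my real settings.

import torch

# placeholder hyperparameters, only for this check
vocab_size, embedding_dim, hidden_size = 100, 8, 16
batch_size, num_classes = 4, 2

model = LSTM()  # the class defined above; it reads the module-level names
a = torch.randint(0, vocab_size, (batch_size, 10))  # two batches of token ids
b = torch.randint(0, vocab_size, (batch_size, 10))  # that differ in content
with torch.no_grad():
    out_a = model(a)
    out_b = model(b)
print(torch.allclose(out_a, out_b))  # expected: False for distinct inputs

If this prints True even before any training, the collapse must be happening in the forward pass itself rather than during training.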