Why are my LSTM's hidden states/outputs the same for all inputs?

Asked: 2019-06-15 06:33:01

Tags: python pytorch lstm text-classification

I am using Python 3 and PyTorch to build a basic LSTM for a text classification problem. However, the model produces the same prediction for every input. I inspected the hidden states for each batch and found that all hidden states within a batch are identical. Different inputs should produce different final hidden states, right? So how do I fix my code?

import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes, batch_size):
        super(LSTM, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        # Trainable embedding lookup (requires_grad is already True by default)
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def init_hidden(self, batch_size_=None):
        if batch_size_ is None:
            batch_size_ = self.batch_size
        # Shape: (num_layers, batch, hidden_size). autograd.Variable is
        # deprecated; plain tensors can be passed to the LSTM directly.
        h0 = torch.zeros(1, batch_size_, self.hidden_size)
        c0 = torch.zeros(1, batch_size_, self.hidden_size)
        return (h0, c0)

    def forward(self, X):
        X = self.word_embeddings(X)  # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        X = X.permute(1, 0, 2)       # nn.LSTM expects (seq_len, batch, embedding_dim)
        # print(X)  # debug: inspect the embedded batch
        self.hidden = self.init_hidden(X.size(1))
        output, (final_hidden_state, final_cell_state) = self.lstm(X, self.hidden)
        # Classify from the final hidden state of the last (only) layer
        return self.fc(final_hidden_state[-1])
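
For reference, here is a minimal sanity check I would run against the class above. The hyperparameter values (vocab_size=100, embedding_dim=8, hidden_size=16, num_classes=2, batch_size=4, seq_len=10) are placeholders made up for this check, not my real settings:

import torch

# Placeholder hyperparameters, chosen only for this sanity check.
vocab_size, embedding_dim, hidden_size, num_classes, batch_size = 100, 8, 16, 2, 4

model = LSTM(vocab_size, embedding_dim, hidden_size, num_classes, batch_size)
model.eval()

# Two batches of distinct random token ids, shape (batch, seq_len).
a = torch.randint(0, vocab_size, (batch_size, 10))
b = torch.randint(0, vocab_size, (batch_size, 10))

with torch.no_grad():
    out_a = model(a)
    out_b = model(b)

# Even with untrained random weights, different inputs should already
# produce different logits. If this prints False but the real data still
# yields identical hidden states, the inputs themselves (e.g. rows that
# become identical after padding/truncation) would be my next suspect.
print(torch.allclose(out_a, out_b))  # expected: False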

0 Answers:

There are no answers yet.