我试图实现一个seq2seq编码器解码器,它从英语句子重建相同的英语句子。解码器工作正常(由老师强制训练,重建有效句子),但编码器为任何句子返回相同的隐藏编码。
我正在使用pytorch,我想知道这是否是一个常见问题。可能的原因是什么?
hidden_size = 300
output_size = vocabularySize
class EncoderLSTM(nn.Module):
def __init__(self,emb_dim=vocabularySize):
super(EncoderLSTM,self).__init__()
self.lstm = nn.LSTM(emb_dim,hidden_size)
def forward(self,X,h):
X=X.view(1,1,-1)
#print(X.shape,h[0].shape,h[1].shape)
st,h2=self.lstm(X,h)
#print(st.shape)
return h2
def initHidden(self):
result = (Variable(torch.zeros(1, 1,hidden_size)),Variable(torch.zeros(1, 1, hidden_size)))
return result
...
hidden = encoder.initHidden()
for i in range(target_length):
encoder_input=target_variable[0][target_length-1-i]
hidden = encoder.forward(encoder_input,hidden)
...
for i in range(1,target_length):
pred,hidden = decoder.forward(decoder_input,hidden)
target = target_variable[0][i].unsqueeze(0)
values,ids = target.data.topk(1)
target_id = Variable(torch.from_numpy(np.array([ids[0][0]])).long())
loss+= criterion(pred[0],target_id)
...
loss/=target_length
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
loss.backward()
encoder_optimizer.step()
decoder_optimizer.step()
答案 0 :(得分:0)
那是不可能的。您一定不能查看每个新输入句子的整个隐藏状态。如果驱动不同的输出,则它们必须不同。也许您正在查看整个批次的某种最终隐藏状态。