我是Pytorch和RNN的新手。我正在学习如何使用RNN来预测数字作为视频的教程:https://www.youtube.com/watch?v=MKA6v99uYKY
在他的代码中,他使用python 3并执行解码,如:
out_unembedded = out.view(-1, hidden_size) @ embedding.weight.transpose(0,1)
我正在使用Python 2并尝试代码:
out_unembedded = out.view(-1, hidden_size).dot( embedding.weight.transpose(0,1))
但似乎不对,然后我尝试像这样解码:
import torch
import torch.nn as nn
from torch.autograd import Variable
word2id = {'hello': 0, 'world': 1, 'I': 2, 'am': 3,'writing': 4,'pytorch': 5}
embeds = nn.Embedding(6, 3)
word_embed = embeds(Variable(torch.LongTensor([word2id['am']])))
id2word = {v: k for k, v in word2id.iteritems()}
index = 0
for row in embeds.weight.split(1):
if(torch.min( torch.eq(row.data,word_embed.data) ) == 1):
print index
print id2word[index]
index+=1
有更专业的方法吗?谢谢!
------------ UPDATE ------------
我找到了在Python 2中替换@的正确方法:
out_unembedded = torch.mm( embedded_output.view(-1, hidden_size),embedding.weight.transpose(0, 1))