如何有效地解码PyTorch中的嵌入?

时间:2018-01-02 11:47:33

标签: python artificial-intelligence pytorch rnn

我是Pytorch和RNN的新手。我正在学习如何使用RNN来预测数字作为视频的教程:https://www.youtube.com/watch?v=MKA6v99uYKY

在他的代码中,他使用python 3并执行解码,如:

out_unembedded = out.view(-1, hidden_size) @ embedding.weight.transpose(0,1)

我正在使用Python 2并尝试代码:

out_unembedded = out.view(-1, hidden_size).dot( embedding.weight.transpose(0,1))

但似乎不对,然后我尝试像这样解码:

import torch
import torch.nn as nn
from torch.autograd import Variable

word2id = {'hello': 0, 'world': 1, 'I': 2, 'am': 3,'writing': 4,'pytorch': 5}
embeds = nn.Embedding(6, 3)
word_embed = embeds(Variable(torch.LongTensor([word2id['am']])))

id2word = {v: k for k, v in word2id.iteritems()}
index = 0
for row in embeds.weight.split(1):
    if(torch.min( torch.eq(row.data,word_embed.data) ) == 1):
        print index
        print id2word[index]
    index+=1

有更专业的方法吗?谢谢!

------------ UPDATE ------------

我找到了在Python 2中替换@的正确方法:

out_unembedded = torch.mm( embedded_output.view(-1, hidden_size),embedding.weight.transpose(0, 1))

1 个答案:

答案 0 :(得分:1)

我终于弄明白了这个问题。两种解码方法不同。

第一个使用

  

@

做点积。它不是搜索精确的解码,而是通过点积计算余弦相似度,并找到最相似的单词。点积之后的值表示目标与具有这种索引的单词之间的相似性。等式是:

enter image description here

构建哈希映射的第二种方法是使用精确编码来查找索引。