Pytorch nn。嵌入错误

时间:2018-07-21 12:27:18

标签: pytorch word-embedding

我正在阅读Word Embedding上的pytorch文档。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(5)

word_to_ix = {"hello": 0, "world": 1, "how":2, "are":3, "you":4}
embeds = nn.Embedding(2, 5)  # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor(word_to_ix["hello"], dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)

输出:

tensor([-0.4868, -0.6038, -0.5581,  0.6675, -0.1974])

这看起来不错,但是如果我用

代替line lookup_tensor
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)

我收到以下错误消息:

RuntimeError: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1524590658547/work/aten/src/TH/generic/THTensorMath.c:343

我不明白为什么它在行hello_embed = embeds(lookup_tensor)上给运行时错误。

1 个答案:

答案 0 :(得分:3)

当您声明 embeds = nn.Embedding(2,5)时,词汇的大小为2,嵌入的大小为5。即,每个单词将由大小为5的向量表示,并且只有vocab中有2个字。

lookup_tensor = torch.tensor(word_to_ix [“ how”],dtype = torch.long)嵌入将尝试查找与vocab中第三个单词相对应的向量,但嵌入的vocab大小为2.这就是为什么您会收到错误消息。

如果您声明 embeds = nn.Embedding(5,5),它应该可以正常工作。