我应该如何理解nn.Embeddings参数num_embeddings和embedding_dim?

时间:2019-11-09 07:27:40

标签: pytorch

我试图适应PyTorch nn模块中的Embedding类。

我注意到很多其他人都遇到了与我自己相同的问题,因此在PyTorch论坛和Stack Overflow上发布了问题,但我仍然有些困惑。

根据official documentation,传递的参数是num_embeddingsembedding_dim,它们分别表示字典(或词汇)的大小以及希望嵌入的维数分别。

我很困惑的是我应该如何准确地解释这些。例如,我运行的小型练习代码:

import torch
import torch.nn as nn


embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # (2, 4)

b = torch.LongTensor([[1, 2, 3], [2, 3, 1], [4, 5, 6], [3, 3, 3], [2, 1, 2],
                      [6, 7, 8], [2, 5, 2], [3, 5, 8], [2, 3, 6], [8, 9, 6],
                      [2, 6, 3], [6, 5, 4], [2, 6, 5]]) # (13, 3)

c = torch.LongTensor([[1, 2, 3, 2, 1, 2, 3, 3, 3, 3, 3],
                      [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]) # (2, 11)

当我通过a变量运行bcembedding时,我得到了形状为(2, 4, 3)(13, 3, 3)的嵌入结果,(2, 11, 3)

让我感到困惑的是,我想到我们拥有的样本数量超过了预定义的嵌入数量,我们应该得到一个错误吗?由于我定义的embedding具有10嵌入,b不应给我一个错误,因为它是一个包含13个3维单词的张量?

1 个答案:

答案 0 :(得分:1)

在您的情况下,这是输入张量的解释方式:

a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # 2 sequences of 4 elements

此外,这就是您的嵌入层的解释方式:

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3) # 10 distinct elements and each those is going to be embedded in a 3 dimensional space

因此,只要输入张量有10个以上的元素,只要它们在[0, 9]范围内,都没有关系。例如,如果我们创建两个元素的张量,例如:

d = torch.LongTensor([[1, 10]]) # 1 sequence of 2 elements

当这个张量通过嵌入层时,我们将得到以下错误:

  

RuntimeError:索引超出范围:尝试访问具有9行的表之外的索引10

总结一下,num_embeddings是词汇表中唯一元素的总数,而embedding_dim是经过嵌入层的每个嵌入矢量的大小。因此,只要张量中的每个元素都在[0, 9]范围内,就可以拥有10个以上元素的张量,因为您定义的词汇量为10个元素。