PyTorch embedding index out of range

Asked: 2019-05-06 18:28:11

Tags: python neural-network nlp pytorch recurrent-neural-network

I'm following the tutorial at https://cs230-stanford.github.io/pytorch-nlp.html. In it, a neural network model is created with nn.Module. The model has an embedding layer, which is initialized here:

self.embedding = nn.Embedding(params['vocab_size'], params['embedding_dim'])

vocab_size is the total number of training samples, which is 4000. embedding_dim is 50. The relevant part of the forward method is below:

def forward(self, s):
        # apply the embedding layer that maps each token to its embedding
        s = self.embedding(s)   # dim: batch_size x batch_max_len x embedding_dim

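I get the exception below when passing a batch to the model with model(train_batch). train_batch is a numeric array of dimension batch_size x batch_max_len. Each sample is a sentence, and each sentence is padded to the length of the longest sentence in the batch.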

  

  File "/Users/liam_adams/Documents/cs512/research_project/custom/model.py", line 34, in forward
    s = self.embedding(s)   # dim: batch_size x batch_max_len x embedding_dim
  File "/Users/liam_adams/Documents/cs512/venv_research/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/liam_adams/Documents/cs512/venv_research/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 117, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/Users/liam_adams/Documents/cs512/venv_research/lib/python3.7/site-packages/torch/nn/functional.py", line 1506, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: index out of range at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:193

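Is the problem here that the embedding is initialized with different dimensions than those of my batch array? My batch_size will be constant, but batch_max_len will change with every batch. This is how it is done in the tutorial.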

3 Answers:

Answer 0: (score: 2)

You have a few things wrong. Please correct them and re-run your code:

  • params['vocab_size'] is the total number of unique tokens. So, it should be len(vocab) in the tutorial.

  • params['embedding_dim'] can be 50 or 100 or whatever value you choose. Most people use something in the range [50, 1000], both extremes included. Both Word2Vec and GloVe use 300-dimensional embeddings for words.

  • self.embedding() accepts an arbitrary batch size, so that does not matter. By the way, in the tutorial, comments such as # dim: batch_size x batch_max_len x embedding_dim indicate the shape of the output tensor of that specific operation, not of the input.
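The three points above can be sketched as follows. This is a minimal illustration, not the tutorial's code: the toy `vocab` dictionary and the two hand-written batches are assumptions made up for the example.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary; in the tutorial this would be built from the dataset.
vocab = {"<pad>": 0, "<unk>": 1, "the": 2, "cat": 3, "sat": 4}

# vocab_size is the number of unique tokens, not the number of training samples.
params = {"vocab_size": len(vocab), "embedding_dim": 50}
embedding = nn.Embedding(params["vocab_size"], params["embedding_dim"])

# Two batches with different batch sizes and different padded lengths.
batch_a = torch.tensor([[2, 3, 4], [2, 3, 0]])   # 2 x 3
batch_b = torch.tensor([[2, 3, 4, 1, 0]] * 4)    # 4 x 5

out_a = embedding(batch_a)
out_b = embedding(batch_b)

# Output shape is batch_size x batch_max_len x embedding_dim in both cases.
print(out_a.shape)
print(out_b.shape)
```

The same layer handles both batches because only the index values, not the batch shape, have to match the size of the embedding table.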

Answer 1: (score: 0)

Found the answer here: https://discuss.pytorch.org/t/embeddings-index-out-of-range-error/12582

I was converting words to indexes, but the indexes were based on the total number of words rather than on vocab_size, which is a smaller set of the most frequent words.
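One way to avoid that mismatch is to assign indices only to the words that are actually kept, and to map everything else to an unknown-word token. The sketch below is an assumption about how this could look, not the code from the linked thread: the toy corpus, the `<pad>`/`<unk>` tokens, `max_words`, and `encode` are all hypothetical names.

```python
from collections import Counter

# Hypothetical toy corpus; real code would read the training sentences.
sentences = [
    ["the", "cat", "sat"],
    ["the", "dog", "sat"],
    ["a", "bird", "flew"],
]

PAD, UNK = "<pad>", "<unk>"
max_words = 3  # hypothetical cap on how many frequent words to keep

counts = Counter(word for sentence in sentences for word in sentence)
kept = [word for word, _ in counts.most_common(max_words)]

# Indices are assigned only to the kept words plus the two special tokens,
# so every index is guaranteed to be smaller than vocab_size.
vocab = {token: index for index, token in enumerate([PAD, UNK] + kept)}
vocab_size = len(vocab)


def encode(sentence):
    """Map each word to its index, falling back to UNK for words not kept."""
    return [vocab.get(word, vocab[UNK]) for word in sentence]


encoded = [encode(sentence) for sentence in sentences]
print(vocab_size, encoded)
```

With this scheme, `vocab_size` is the value to pass as the first argument of `nn.Embedding`.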

Answer 2: (score: 0)

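In nn.Embedding, the embedding size (the first argument, num_embeddings) should be at least max(input_data) + 1, because valid indices run from 0 to num_embeddings - 1. Also check the data type of input_data, since the indices must be integers.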
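A small check run before the lookup can turn the opaque "index out of range" error into a descriptive one. The helper `check_indices` below is a hypothetical name written for this illustration; it is not part of PyTorch.

```python
import torch
import torch.nn as nn


def check_indices(batch, embedding):
    """Raise a descriptive error if `batch` cannot be looked up in `embedding`."""
    if batch.dtype not in (torch.int32, torch.int64):
        raise TypeError(f"indices must be an integer tensor, got {batch.dtype}")
    low, high = int(batch.min()), int(batch.max())
    if low < 0 or high >= embedding.num_embeddings:
        raise ValueError(
            f"indices span [{low}, {high}] but the table has only "
            f"{embedding.num_embeddings} rows"
        )


embedding = nn.Embedding(10, 4)
good = torch.tensor([[0, 3, 9]])
bad = torch.tensor([[0, 3, 10]])  # 10 is out of range: valid indices are 0..9

check_indices(good, embedding)  # passes silently
try:
    check_indices(bad, embedding)
except ValueError as err:
    print(err)
```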