关于rnn.pack_padded_sequence的用法

时间:2018-10-24 13:15:39

标签: python-3.x pytorch

当我使用torch.nn.utils.rnn.pack_padded_sequence()时,发生了错误。 这是我的代码:

import torch
import numpy as np
x = torch.from_numpy(np.array([[1,2,3,4,5,6,0,0],[6,7,8,9,0,0,0,0],[12,83,84,0,0,0,0,0]]))
length =[6,4,3]
print(torch.nn.utils.rnn.pack_padded_sequence(input=x, lengths=length, batch_first=True))

错误如下:

Traceback (most recent call last):
  File "/home/pc/PycharmProjects/padded/padded.py", line 112, in <module>
    print(torch.nn.utils.rnn.pack_padded_sequence(input=x, lengths=length, batch_first=True))
  File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 67, in wrapper
    if not might_trace(args):
  File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 141, in might_trace
    first_arg = args[0]
IndexError: tuple index out of range

但是当我输入以下代码时:

import torch
import numpy as np
x = torch.from_numpy(np.array([[1,2,3,4,5,6,0,0],[6,7,8,9,0,0,0,0],[12,83,84,0,0,0,0,0]]))
length =[6,4,3]
print(torch.nn.utils.rnn.pack_padded_sequence(x, lengths=length, batch_first=True))

结果正常。而且我不知道为什么。

您能帮我解决问题吗?

1 个答案:

答案 0 :(得分:0)

这听起来很愚蠢,但我认为这是由于包装器/低级翻译器函数在PyTorch函数上的运行方式所致:

根据我对python的*args, **kwargs装饰器的了解(请参阅更多here),问题在于,只有在不使用= < / em>。

意思是,它们存储在关键字/值对字典中。相反,如果我们查看错误消息的相关部分(这一部分:File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py"),我们可以看到以下内容:

first_arg = args[0]

表示它想使用索引来解决这个问题。不幸的是,由于关键字仅在键/值字典中存储并传递,因此我们无法使用索引解决此问题,随后会引发错误。

到目前为止,我还无法找到解决该问题的方法,因为它是非常抽象的,并且(可能)被许多不同的类使用,特别是因为它被传递通过了中间层(在第67行,请参见堆栈)跟踪)。