当我使用torch.nn.utils.rnn.pack_padded_sequence()
时,发生了错误。
这是我的代码:
import torch
import numpy as np
x = torch.from_numpy(np.array([[1,2,3,4,5,6,0,0],[6,7,8,9,0,0,0,0],[12,83,84,0,0,0,0,0]]))
length =[6,4,3]
print(torch.nn.utils.rnn.pack_padded_sequence(input=x, lengths=length, batch_first=True))
错误如下:
Traceback (most recent call last):
File "/home/pc/PycharmProjects/padded/padded.py", line 112, in <module>
print(torch.nn.utils.rnn.pack_padded_sequence(input=x, lengths=length, batch_first=True))
File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 67, in wrapper
if not might_trace(args):
File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 141, in might_trace
first_arg = args[0]
IndexError: tuple index out of range
但是当我输入以下代码时:
import torch
import numpy as np
x = torch.from_numpy(np.array([[1,2,3,4,5,6,0,0],[6,7,8,9,0,0,0,0],[12,83,84,0,0,0,0,0]]))
length =[6,4,3]
print(torch.nn.utils.rnn.pack_padded_sequence(x, lengths=length, batch_first=True))
结果正常。而且我不知道为什么。
您能帮我解决问题吗?
答案 0 :(得分:0)
这听起来很愚蠢,但我认为这是由于包装器/低级翻译器函数在PyTorch函数上的运行方式所致:
根据我对python的*args, **kwargs
装饰器的了解(请参阅更多here),问题在于,只有在不使用=
< / em>。
意思是,它们存储在关键字/值对字典中。相反,如果我们查看错误消息的相关部分(这一部分:File "/home/pc/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py"
),我们可以看到以下内容:
first_arg = args[0]
表示它想使用索引来解决这个问题。不幸的是,由于关键字仅在键/值字典中存储并传递,因此我们无法使用索引解决此问题,随后会引发错误。
到目前为止,我还无法找到解决该问题的方法,因为它是非常抽象的,并且(可能)被许多不同的类使用,特别是因为它被传递通过了中间层(在第67行,请参见堆栈)跟踪)。