RuntimeError:后端CUDA的预期对象,但参数#3'索引'的后端CPU

时间:2019-11-22 12:09:03

标签: python nlp

    LSTM(
      (embed): Embedding(139948, 12, padding_idx=0)
      (lstm): LSTM(12, 12, num_layers=2, batch_first=True, bidirectional=True)
      (lin): Linear(in_features=240, out_features=6, bias=True)
    )
    Train epoch : 1,  loss : 771.319284286499,  accuracy :0.590
    =================================================================================================
    Traceback (most recent call last):enter code here
      File "C:/Users/Administrator/PycharmProjects/untitled/example.py", line 297, in <module>
        scores = model(x_test, x_test_seq_length)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "C:/Users/Administrator/PycharmProjects/untitled/example.py", line 141, in forward
        x = self.embed(x)  # sequence_length(max_len), batch_size, embed_size
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py", line 117, in forward
        self.norm_type, self.scale_grad_by_freq, self.sparse)
      File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1506, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Expected object of backend CUDA but got backend CPU for argument #3 'index'

它在训练集上工作正常,但我在测试集中不断出现该错误。我已经思考了10个小时。

出什么问题了?

2 个答案:

答案 0 :(得分:1)

似乎您的程序希望使用GPU运行,但正在CPU上运行。确保正确设置了程序的GPU设置,并且所使用的CUDA版本是最新的。

您可以在此处找到有关此的更多信息(假设您正在使用tensorflow): https://www.tensorflow.org/install/gpu

答案 1 :(得分:0)

如果您在CPU上训练了模型,那么将以某种方式加载测试数据并将其转换为CUDA数据类型。因此,您可以通过将输入张量移动到CPU设备来解决此问题。并同时移动模型(不会造成伤害)。

可以这样做:

>>> import torch
>>> device = torch.device("cpu")
>>> # move the model
>>> model = model.to(device)
>>> # move any input tensors
>>> test_data = test_data.to(device)