pytorch如何从张量中删除cuda()

时间:2018-08-03 01:36:11

标签: python type-conversion pytorch tensor

我得到了TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor)

如何将torch.cuda.FloatTensor转换为torch.LongTensor

  Traceback (most recent call last):
  File "train_v2.py", line 110, in <module>
    main()
  File "train_v2.py", line 81, in main
    model.update(batch)
  File "/home/Desktop/squad_vteam/src/model.py", line 131, in update
    loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y)
  File "/home/Desktop/squad_vteam/src/model.py", line 94, in adversarial_loss
    adv_embedding = torch.LongTensor(adv_embedding)
TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor)

3 个答案:

答案 0 :(得分:5)

您有一个浮点张量-keep class android.support.v7.** { *; } 并想将其转换为long,您执行f

您有long_tensor = f.long()张量,即数据在gpu上,并且要将其移动到cpu上,可以执行cuda

因此要将torch.cuda.Float张量cuda_tensor.cpu()转换为torch.long做A

答案 1 :(得分:3)

Pytorch 0.4.0的最佳实践是编写device agnostic code:也就是说,除了使用.cuda().cpu()之外,您还可以简单地使用.to(torch.device("cpu"))

A = A.to(dtype=torch.long, device=torch.device("cpu"))

请注意,.to()不是“就地”操作(例如,参见this answer),因此您需要将A.to(...)分配回A中。 / p>

答案 2 :(得分:1)

如果您有张量t

t = t.cpu() 

将是旧的方式。

t = t.to("cpu")

将是新的API。