检查预期对象是后端CUDA还是CPU?

时间:2019-04-23 07:29:03

标签: pytorch

  • 我正在尝试在CPU和CUDA上运行代码。
  • 创建对象时会出现问题,因为我需要知道预期的结果。

在创建计算机之前,我需要确定计算机是否需要CUDA或CPU张量。

代码:

def initilize(self, input):
     self.x = torch.nn.Parameter(torch.zeros((1,M))

def run(self,x,state):
  B = torch.cat((self.x,h)

这将输出: 错误:“后端CUDA的预期对象,但参数#1拥有后端CPU”

代码提示:

def initilize(self, input):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  if (expecting_cuda == True):
     self.x = torch.nn.Parameter(torch.zeros((1,M)).to(device))
  else
     self.x = torch.nn.Parameter(torch.zeros((1,M))

def run(self,h):
  B = torch.cat((self.x,h)
  • 问题: 如何弄清楚计算机的期望?

  • 限制: 我正在执行预定义的“检查”过程,因此无法将包含有关CUDA或CPU的信息的参数发送到函数“初始化”中。

1 个答案:

答案 0 :(得分:1)

您可以只使用self.x = torch.nn.Parameter(torch.zeros((1,M)).to(device)),而无需使用if (expecting_cuda == True):,因为to(device)也适用于CPU。