火炬,如何检查变量是否是CUDA?

时间:2016-08-30 10:45:24

标签: lua torch

我发现像type()这样的函数来识别哪个变量是CudaTensor或Normal。

require('cutorch')

x = torch.Tensor(3,3)
x = x:cuda()

if type(x) == 'CudaTensor' then -- What function should be used?
    print('x is CUDA tensor')
else
    print('x is normal tensor')
end

1 个答案:

答案 0 :(得分:2)

使用:type()张量方法:

cutorch = require('cutorch')

x = torch.Tensor(3,3)
x = x:cuda()

if x:type() == 'torch.CudaTensor' then
    print('x is CUDA tensor')
else
    print('x is normal tensor')
end