I'm looking for a function like type() that can tell whether a variable is a CudaTensor or a normal tensor.
require('cutorch')
x = torch.Tensor(3,3)
x = x:cuda()
if type(x) == 'CudaTensor' then -- What function should be used?
  print('x is CUDA tensor')
else
  print('x is normal tensor')
end
Answer (score: 2)
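Use the :type() tensor method, which returns the tensor's full class name as a string (for example 'torch.CudaTensor'):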
cutorch = require('cutorch')
x = torch.Tensor(3,3)
x = x:cuda()
if x:type() == 'torch.CudaTensor' then
  print('x is CUDA tensor')
else
  print('x is normal tensor')
end
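The check in the question never matches because Lua's built-in type() returns 'userdata' for every Torch tensor. Torch also provides torch.type(obj), which for tensors returns the same class name as the :type() method. Below is a minimal sketch that runs with only the torch package (no GPU needed). The helper names isCudaTensor and isAnyCudaTensor are hypothetical, and isAnyCudaTensor assumes a cutorch version whose extra CUDA tensor types (such as torch.CudaDoubleTensor) all have names starting with 'torch.Cuda'.

```lua
require('torch')

-- Hypothetical helper: true only for the float CUDA tensor type that x:cuda() produces.
local function isCudaTensor(t)
  return torch.type(t) == 'torch.CudaTensor'
end

-- Hypothetical helper: true for any tensor whose class name starts with 'torch.Cuda'
-- (assumes other cutorch tensor types follow this naming scheme).
local function isAnyCudaTensor(t)
  local prefix = 'torch.Cuda'
  return torch.type(t):sub(1, #prefix) == prefix
end

local x = torch.Tensor(3, 3)  -- CPU tensor; torch.DoubleTensor unless the default type was changed
print(type(x))                -- userdata: Lua's built-in type() cannot tell tensor classes apart
print(torch.type(x))          -- torch.DoubleTensor
print(isCudaTensor(x), isAnyCudaTensor(x))
```

With cutorch loaded on a machine with a GPU, passing x:cuda() to either helper should return true, while the CPU tensor above yields false for both.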