将python列表转换为pytorch张量

时间:2020-02-06 07:55:41

标签: python pytorch

将python的数字列表转换为pytorch Tensor时遇到问题: 这是我的代码:

caption_feat = [int(x)  if x < 11660  else 3 for x in caption_feat]

打印caption_feat给出:[1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]
我进行这样的转换:tmp2 = torch.Tensor(caption_feat) 现在打印tmp2可以得到:tensor([1.0000e+00, 9.9030e+03, 7.8760e+03, 9.9710e+03, 2.7700e+03, 2.4350e+03, 1.0441e+04, 9.3700e+03, 2.0000e+00])
但是我希望得到:tensor([1. , 9903, , 9971. ......]) 有想法吗?

4 个答案:

答案 0 :(得分:1)

您可以通过定义 list 直接将 python Tensor 转换为 pytorch dtype。例如,

import torch

a_list = [3,23,53,32,53] 
a_tensor = torch.Tensor(a_list)
print(a_tensor.int())

>>> tensor([3,23,53,32,53])

答案 1 :(得分:0)

如果所有元素都是整数,则可以通过定义dtype

来制作整数割炬张量
>>> a_list = [1, 9903, 7876, 9971, 2770, 2435, 10441, 9370, 2]
>>> tmp2 = torch.tensor(a_list, dtype=torch.int)
>>> tmp2
tensor([    1,  9903,  7876,  9971,  2770,  2435, 10441,  9370,     2],
       dtype=torch.int32)

torch.Tensor返回torch.float32时,它会以科学计数法打印数字

>>> tmp2 = torch.Tensor(a_list)
>>> tmp2
tensor([1.0000e+00, 9.9030e+03, 7.8760e+03, 9.9710e+03, 2.7700e+03, 2.4350e+03,
        1.0441e+04, 9.3700e+03, 2.0000e+00])
>>> tmp2.dtype
torch.float32

答案 2 :(得分:0)

尝试

torch.IntTensor(caption_feat)

您可以在https://pytorch.org/docs/stable/tensors.html

上看到其他类型

答案 3 :(得分:0)

一个简单的方法是将列表转换为numpy数组,指定所需的dtype并在新数组上调用torch.from_numpy

玩具示例:

some_list = [1, 10, 100, 9999, 99999]
tensor = torch.from_numpy(np.array(some_list, dtype=np.int))

其他建议的另一个选项是在创建张量时指定类型:

torch.tensor(some_list, dtype=torch.int)

两者都应该正常工作。