从生成器创建火炬张量

时间:2019-03-22 20:31:21

标签: pytorch

我尝试从生成器构造张量,如下所示:

>>> torch.tensor(i**2 for i in range(10))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not infer dtype of generator

目前我只是这样做:

>>> torch.tensor([i**2 for i in range(10)])
tensor([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81])

是否有避免这种中间列​​表的方法?

2 个答案:

答案 0 :(得分:1)

我不明白为什么要使用发电机。该列表在这里并没有真正改变。

问题是:您是否要先在 Python 中创建数据,然后再移动,然后再将其创建为 PyTorch (大多数情况下速度较慢) 是否要在 PyTorch 中直接创建
(生成器总是会首先在Python中创建数据)

因此,如果您要加载数据,情况会有所不同,但是如果您要生成数据,我认为没有理由不应该在直接使用PyTorch


如果您想直接在PyTorch中创建示例列表,可以使用arangepow

torch.arange(10).pow(2)

输出:

tensor([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81])

torch.arange(10)的工作方式与python中的range相同,因此它的用途与range完全一样。然后pow(2)只是将张量带到2次方。

但是一旦使用pow创建了张量,您也可以执行其他所有类型的计算,而不是arange

答案 1 :(得分:1)

正如@ blue-phoenox指出的那样,最好使用内置的PyTorch函数直接创建张量。但是,如果必须处理生成器,建议使用numpy作为中间阶段。由于PyTorch避免复制numpy数组,因此它应该表现出色(与简单的列表理解相比)

>>> import torch
>>> import numpy as np
>>> torch.from_numpy(np.fromiter((i**2 for i in range(10)), int))
tensor([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81])