将元组/张量列表视为单个张量

时间:2019-05-28 14:08:21

标签: pytorch

我将Pytorch用于一些机器人强化学习任务。我想同时使用图像和有关状态的信息作为对此任务的观察。我使用的实现不直接支持此功能,因此我要进行一些修改。预期的观察结果要么是状态,即一维张量,要么是图像,即3维张量(通道,宽度,高度)。在我的任务中,我希望观察到的是张量元组。

在我的代码库中的许多地方,观察结果当然应该是单个张量,而不是张量的元组。 是否有一种简单的方法将一堆张量视为一个张量?

例如,我想要:

observation.to(device)

observation是一个张量时正常工作,并且在.to(device)是张量的元组时在每个张量上调用observation

创建支持该数据的数据类型应该足够简单,但是我想知道这样的数据类型是否已经存在?到目前为止我还没有发现任何东西。

1 个答案:

答案 0 :(得分:1)

如果张量大小都相同,则可以使用torch.stack将它们连接成一个具有一维的张量。

示例:

>>> import torch
>>> a=torch.randn(2,1)
>>> b=torch.randn(2,1)
>>> c=torch.randn(2,1)
>>> a
tensor([[ 0.7691],
        [-0.0297]])
>>> b
tensor([[ 0.4844],
        [-0.9142]])
>>> c
tensor([[ 0.0210],
        [-1.1543]])
>>> torch.stack((a,b,c))
tensor([[[ 0.7691],
         [-0.0297]],

        [[ 0.4844],
         [-0.9142]],

        [[ 0.0210],
         [-1.1543]]])

然后您可以使用torch.unbind转到另一个方向。