我将Pytorch用于一些机器人强化学习任务。我想同时使用图像和有关状态的信息作为对此任务的观察。我使用的实现不直接支持此功能,因此我要进行一些修改。预期的观察结果要么是状态,即一维张量,要么是图像,即3维张量(通道,宽度,高度)。在我的任务中,我希望观察到的是张量元组。
在我的代码库中的许多地方,观察结果当然应该是单个张量,而不是张量的元组。 是否有一种简单的方法将一堆张量视为一个张量?
例如,我想要:
observation.to(device)
在observation
是一个张量时正常工作,并且在.to(device)
是张量的元组时在每个张量上调用observation
。
创建支持该数据的数据类型应该足够简单,但是我想知道这样的数据类型是否已经存在?到目前为止我还没有发现任何东西。
答案 0 :(得分:1)
如果张量大小都相同,则可以使用torch.stack将它们连接成一个具有一维的张量。
示例:
>>> import torch
>>> a=torch.randn(2,1)
>>> b=torch.randn(2,1)
>>> c=torch.randn(2,1)
>>> a
tensor([[ 0.7691],
[-0.0297]])
>>> b
tensor([[ 0.4844],
[-0.9142]])
>>> c
tensor([[ 0.0210],
[-1.1543]])
>>> torch.stack((a,b,c))
tensor([[[ 0.7691],
[-0.0297]],
[[ 0.4844],
[-0.9142]],
[[ 0.0210],
[-1.1543]]])
然后您可以使用torch.unbind转到另一个方向。