如何在PyTorch中组合/堆叠张量并组合尺寸?

时间:2019-02-12 04:58:16

标签: python stack concatenation pytorch tensor

我需要将代表[1,84,84]大小的代表灰度图像的4个张量组合成[4,84,84]形状的堆叠,代表四个灰度图像,每个图像都表示为“通道”张量样式CxWxH。

我正在使用PyTorch。

我曾经尝试过使用torch.stack和torch.cat,但是如果其中之一是解决方案,那么我没有运气找出正确的准备/方法来获得结果。

谢谢您的帮助。

import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

    def get_stack(self):
        #t =  torch.cat(tuple(self.phi),dim=0)
        t =  torch.stack(tuple(self.phi),dim=0)
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
            self.phi.append(torch.tensor([1,nS[1], nS[2]]))

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)

上面的代码返回:

print(a.phi)

deque([tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84])], maxlen=4)

print(s, s.shape)

tensor([[ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84]]) torch.Size([4, 3])

但是我想返回的就是[4,84,84]。我怀疑这很简单,但却在逃避我。

1 个答案:

答案 0 :(得分:0)

似乎您误解了torch.tensor([1, 84, 84])在做什么。让我们看一下:

torch.tensor([1, 84, 84])
print(x, x.shape) #tensor([ 1, 84, 84]) torch.Size([3])

您可以从上面的示例中看到,它为您提供了只有一个维度的张量。

从问题陈述中,您需要一个形状为[1,84,84]的张量。 外观如下:

from collections import deque
import torch
import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

    def get_stack(self):
        t =  torch.cat(tuple(self.phi),dim=0)
#         t =  torch.stack(tuple(self.phi),dim=0)
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
#             self.phi.append(torch.tensor([1,nS[1], nS[2]]))
            self.phi.append(torch.zeros([1,nS[1], nS[2]]))

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)

请注意,torch.cat给出了形状为[4,84,84]的张量,torch.stack给出了形状为[4,1,84,84]的张量。它们的区别可以在What's the difference between torch.stack() and torch.cat() functions?

中找到