遍历批处理图像加载器pytorch

时间:2019-12-20 03:59:13

标签: numpy pytorch

假设我与

  

imgs = torch.Size([128,1,28,28])

如果我想遍历每个图像

for img in imgs:
   print(img.shpae) -> torch.Size([1, 28, 28])

如果我想为每个图像获取torch.Size([1,1,28,28]),该怎么办?

2 个答案:

答案 0 :(得分:1)

您可以初始将张量调整为[128, 1, 1, 28, 28]

的形状
# tensor.resize_((`new_shape`))    
imgs.resize_((128, 1, 1, 28, 28))

当您遍历每张图像时,否将为所需的形状[1、1、28、28]。

第二,如果您不想更改原始数据,请分别调整每个图像的形状

# tensor.resize_((`new_shape`))    
img.resize_((1, 1, 28, 28))

看看PyTorch documentation

答案 1 :(得分:1)

unsqueeze只需变暗,您要在该位置添加一个额外的单例尺寸。

imgs = torch.zeros([128, 1, 28, 28])

# dim (int) – the index at which to insert the singleton dimension
imgs.unsqueeze_(dim = 1)

imgs.shape
>>> torch.Size([128, 1, 1, 28, 28])