如何在pytorch图像处理模型中处理具有多个图像的样本?

时间:2020-11-11 14:21:46

标签: pytorch tensor mini-batch

我的模型训练涉及对同一图像的多个变体进行编码,然后对图像的所有变体所产生的表示求和。

数据加载器产生张量批次[batch_size,num_variants,1,height,width]1对应于图像颜色通道。

如何在pytorch中使用迷你批次训练模型? 我正在寻找通过网络转发所有batch_size×num_variant图像并汇总所有变体组的结果的正确方法。

我当前的解决方案包括展平前两个维度并进行for循环以汇总表示形式,但是我觉得应该有更好的方法,并且我不确定渐变是否会记住所有内容。

1 个答案:

答案 0 :(得分:1)

不确定我是否正确理解了您,但是我想这就是您想要的(比如说批处理图像张量称为image):

Nb, Nv, inC, inH, inW = image.shape

# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv, inC, inH, inW)

output = model(image)
_, outC, outH, outW = output.shape[1]

# reshapes the output such that dim==1 indicates variants
output = output.reshape(Nb, Nv, outC, outH, outW)

# summing over the variants and lose the dimension of summation, [Nb, outC, outH, outW]
output = output.sum(dim=1, keepdim=False)

在输入和输出通道/大小不同的情况下,我使用了inCoutCinH等。