我正在尝试编写用于混合训练的函数。在此site上,我找到了一些代码,并适应了以前的代码。但是在原始代码中,批次(64)仅生成一个随机变量。但是我想为每张照片批量分配随机值。 具有一个批处理变量的代码:
def mixup_data(x, y, alpha=1.0):
lam = np.random.beta(alpha, alpha)
batch_size = x.size()[0]
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index,:]
mixed_y = lam * y + (1 - lam) * y[index,:]
return mixed_x, mixed_y
输入的 x和y来自pytorch DataLoader。
x输入大小:torch.Size([64, 3, 256, 256])
y输入大小:torch.Size([64, 3474])
此代码运行良好。然后我将其更改为:
def mixup_data(x, y):
batch_size = x.size()[0]
lam = torch.rand(batch_size)
index = torch.randperm(batch_size)
mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]
return mixed_x, mixed_y
但是它给出了一个错误:RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3
我如何理解代码的工作方式是批量获取第一张图像并乘以lam
张量(长64个值)中的第一个值。我该怎么办?
答案 0 :(得分:1)
您需要替换以下行:
lam = torch.rand(batch_size)
作者
lam = torch.rand(batch_size, 1, 1, 1)
在您当前的代码中,lam[index] * x
的乘法是不可能的,因为lam[index]
的大小为torch.Size([64])
,而x
的大小为torch.Size([64, 3, 256, 256])
。因此,您需要将lam[index]
的大小设置为torch.Size([64, 1, 1, 1])
,以便它可以广播。
为应对以下声明:
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
我们可以在语句前重塑lam
张量。
lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
答案 1 :(得分:0)
问题在于两个相乘的张量的大小不匹配。让我们以lam[index] * x
为例。大小如下:
x
:torch.Size([64, 3, 256, 256])
lam[index]
:torch.Size([64])
为了将它们相乘,它们应该具有相同的大小,其中lam[index]
对每批[3, 256, 256]
使用相同的值,因为您想将该批次中的每个元素都乘以相同的值,但每个批次都不同。
lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])
.expand_as(x)
重复奇异尺寸,使其具有与x相同的大小,有关详细信息,请参见.expand()
documentation。
您不需要扩展张量,因为如果存在奇异尺寸,PyTorch会自动为您完成。这就是所谓的广播:PyTorch - Broadcasting Semantics。因此,将大小torch.Size([64, 1, 1, 1])
与x
相乘就足够了。
lam[index].view(batch_size, 1, 1, 1) * x
y
同样适用,但大小为torch.Size([64, 1])
,因为y
的大小为torch.Size([64, 3474])
。
mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]
仅需注意一点,lam[index]
仅重新排列lam
的元素,但是由于您是随机创建的,因此无论您是否重新排列都无济于事。唯一重要的是x
和y
就像原始代码一样被重新排列。