如何将4维PyTorch张量乘以1维张量?

时间:2020-05-10 13:49:31

标签: python pytorch torch

我正在尝试编写用于混合训练的函数。在此site上,我找到了一些代码,并适应了以前的代码。但是在原始代码中,批次(64)仅生成一个随机变量。但是我想为每张照片批量分配随机值。 具有一个批处理变量的代码:

def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    mixed_y = lam * y + (1 - lam) * y[index,:]

    return mixed_x, mixed_y
输入的

x和y来自pytorch DataLoader。 x输入大小:torch.Size([64, 3, 256, 256]) y输入大小:torch.Size([64, 3474])

此代码运行良好。然后我将其更改为:

def mixup_data(x, y):
    batch_size = x.size()[0]
    lam = torch.rand(batch_size)
    index = torch.randperm(batch_size)

    mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]

    return mixed_x, mixed_y

但是它给出了一个错误:RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

我如何理解代码的工作方式是批量获取第一张图像并乘以lam张量(长64个值)中的第一个值。我该怎么办?

2 个答案:

答案 0 :(得分:1)

您需要替换以下行:

lam = torch.rand(batch_size)

作者

lam = torch.rand(batch_size, 1, 1, 1)

在您当前的代码中,lam[index] * x的乘法是不可能的,因为lam[index]的大小为torch.Size([64]),而x的大小为torch.Size([64, 3, 256, 256])。因此,您需要将lam[index]的大小设置为torch.Size([64, 1, 1, 1]),以便它可以广播。

为应对以下声明:

mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

我们可以在语句前重塑lam张量。

lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

答案 1 :(得分:0)

问题在于两个相乘的张量的大小不匹配。让我们以lam[index] * x为例。大小如下:

  • xtorch.Size([64, 3, 256, 256])
  • lam[index]torch.Size([64])

为了将它们相乘,它们应该具有相同的大小,其中lam[index]对每批[3, 256, 256]使用相同的值,因为您想将该批次中的每个元素都乘以相同的值,但每个批次都不同。

lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])

.expand_as(x)重复奇异尺寸,使其具有与x相同的大小,有关详细信息,请参见.expand() documentation

您不需要扩展张量,因为如果存在奇异尺寸,PyTorch会自动为您完成。这就是所谓的广播:PyTorch - Broadcasting Semantics。因此,将大小torch.Size([64, 1, 1, 1])x相乘就足够了。

lam[index].view(batch_size, 1, 1, 1) * x

y同样适用,但大小为torch.Size([64, 1]),因为y的大小为torch.Size([64, 3474])

mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]

仅需注意一点,lam[index]仅重新排列lam的元素,但是由于您是随机创建的,因此无论您是否重新排列都无济于事。唯一重要的是xy就像原始代码一样被重新排列。