PyTorch用于什么变换?

时间:2018-04-24 13:08:56

标签: image input transformation pytorch

我是Pytorch的新手,而不是CNN的专家。 我已经使用他们提供的Tutorial Pytorch教程完成了一个成功的分类器,但是在加载数据时我真的不明白我在做什么。

他们为训练做了一些数据扩充和规范化,但是当我尝试修改参数时,代码不起作用。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

我是否在扩展训练数据集?我没有看到数据增加。

为什么我修改了transforms.RandomResizedCrop(224)的值,数据加载停止工作?

我是否还需要转换测试数据集?

我对他们所做的数据转换感到有点困惑。

2 个答案:

答案 0 :(得分:14)

transforms.Compose只是提供给它的所有变换。因此,transforms.Compose中的所有变换都将逐个应用于输入。

列车转换

  1. transforms.RandomResizedCrop(224):这会随机从输入图像中提取大小为(224, 224)的小块。因此,它可能会选择这条路径来自topleft,bottomright或其间的任何位置。因此,您正在进行此部分的数据扩充。此外,更改此值对于模型中完全连接的图层不会很好,因此不建议更改此值。
  2. transforms.RandomHorizontalFlip():我们的图片大小为(224, 224)后,我们可以选择翻转它。这是数据扩充的另一部分。
  3. transforms.ToTensor():这只是将您的输入图像转换为PyTorch张量。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):这只是输入数据缩放,必须为数据集预先计算这些值(mean和std)。也不建议更改这些值。
  5. 验证转换

    1. transforms.Resize(256):首先将输入图片的大小调整为(256, 256)
    2. transforms.CentreCrop(224):裁剪形状(224, 224)
    3. 图像的中心部分

      休息与火车相同

      P.S。:您可以在official docs

      中详细了解这些转换

答案 1 :(得分:1)

对于数据扩充方面的歧义,我将向您介绍以下答案:

Data Augmentation in PyTorch

但是简而言之,假设您只有随机的水平翻转变换,当您遍历图像数据集时,有的返回为原始,有的返回为翻转(不返回翻转后的原始图像)。换句话说,一次迭代中返回的图像数量与数据集的原始大小相同,并且不会增加。