在pytorch中加载数据时出错:'Image'对象没有属性'shape'

时间:2017-09-29 06:15:50

标签: python image-processing deep-learning pytorch

我使用基于ImageNet training in PyTorch的代码调整resnet152,并且在加载数据时发生错误,并且仅在处理了几批图像后才发生错误。我该如何解决这个问题。 以下代码是产生相同错误的简化代码:

# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(train_img_dir, transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=256, shuffle=True,
    num_workers=1, pin_memory=True)
for i, (input_x, target) in enumerate(train_loader):
    if i % 10 == 0:
        print(i)
        print(input_x.shape)
        print(target.shape)

错误

0
torch.Size([256, 3, 224, 224])
torch.Size([256])
10
torch.Size([256, 3, 224, 224])
torch.Size([256])
20
torch.Size([256, 3, 224, 224])
torch.Size([256])
30
torch.Size([256, 3, 224, 224])
torch.Size([256])
----------------------------------------------------------------------
AttributeError                       Traceback (most recent call last)
<ipython-input-48-792d6ca206df> in <module>()
----> 1 for i, (input_x, target) in enumerate(train_loader):
      2     if i % 10 == 0:
      3 #     sample_img = input_x[0]
      4         print(i)
      5         print(input_x.shape)

/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    200                 self.reorder_dict[idx] = batch
    201                 continue
--> 202             return self._process_next_batch(batch)
    203 
    204     next = __next__  # Python 2 compatibility

/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    220         self._put_indices()
    221         if isinstance(batch, ExceptionWrapper):
--> 222             raise batch.exc_type(batch.exc_msg)
    223         return batch
    224 

AttributeError: Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 41, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/datasets/folder.py", line 118, in __getitem__
    img = self.transform(img)
  File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 369, in __call__
    img = t(img)
  File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 706, in __call__
    i, j, h, w = self.get_params(img)
  File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 693, in get_params
    w = min(img.size[0], img.shape[1])
AttributeError: 'Image' object has no attribute 'shape'

1 个答案:

答案 0 :(得分:4)

$stmt = $db_conn->prepare($query); $stmt->execute($my_array); 中存在错误。在错误消息的最后一行,它应该是transforms.RandomSizedCrop.get_params()而不是img.size

只有当裁剪连续10次失败(它回落到中央裁剪)时,才会执行包含错误的行。这就是每批图像都不会出现此错误的原因。

我已经提交了PR以修复它。要快速解决问题,您可以修改img.shape文件并将所有/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py更改为img.shape

编辑: PR已合并。您可以在GitHub上安装最新的img.size来修复它。