Traceback (most recent call last) in Colab when iterating over a DataLoader in PyTorch

Date: 2018-12-22 09:28:40

Tags: python python-3.x deep-learning pytorch vgg-net

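I am working on a project that classifies flower images with the pretrained VGG19 model in PyTorch.

I use only the model's feature layers and attach my own custom classifier.

However, when I start the for loop that feeds the images to the classifier and computes the accuracy over the epochs, I get an error.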

I am not sure what the problem is, because all the error says is "Traceback (most recent call last)".
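"Traceback (most recent call last)" is only the header Python prints before the call stack. The frames are listed oldest first, so the line that actually names the error is the last one. As a small illustration using only the standard traceback module (the helper name last_error_line is hypothetical), that final line can be extracted from a caught exception:

```python
import traceback


def last_error_line(exc: BaseException) -> str:
    """Return the 'ExceptionType: message' text that ends a Python traceback."""
    # format_exception_only formats just the exception, without the stack frames
    return "".join(traceback.format_exception_only(type(exc), exc)).strip()


try:
    raise AttributeError("module 'PIL.Image' has no attribute 'register_extensions'")
except AttributeError as err:
    print(last_error_line(err))
```

Applied to the traceback in this question, the informative line is the final AttributeError, which points into PIL rather than into the training loop.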

Below is my notebook.

The cell that raises the error:

#training the classifier
import torch
from torch import nn, optim

# model, train_dataloader and valid_dataloader are defined in earlier notebook cells
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.classifier.parameters(), lr=0.01)

steps = 0
running_loss = 0
epochs = 5
print_every = 5

for epoch in range(epochs):
    for images, labels in train_dataloader:
        steps += 1
        optimizer.zero_grad()

        logps = model(images)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0
            model.eval()

            with torch.no_grad():
                for images, labels in valid_dataloader:
                    logps = model(images)
                    batch_loss = criterion(logps, labels)

                    test_loss += batch_loss.item()

                    # Calculate top-1 accuracy; labels.view(*top_class.shape)
                    # only fits when top_class has a single column (k=1)
                    ps = torch.exp(logps)
                    top_p, top_class = ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss/len(valid_dataloader):.3f}.. "
                  f"Test accuracy: {accuracy/len(valid_dataloader):.3f}")
            running_loss = 0
            model.train()
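The accuracy bookkeeping in the validation loop amounts to this: for each image, take the k highest-scoring classes, count a hit if the true label is among them, and average the hits. A framework-free sketch of that computation, using plain Python lists instead of tensors (topk_accuracy is a hypothetical helper, not a PyTorch function):

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Fraction of rows whose true label is among the k highest scores."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one example")
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row (what tensor.topk(k)[1] returns)
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top
    return hits / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1), topk_accuracy(scores, labels, k=2))
```

With k=1 this matches top_class == labels.view(*top_class.shape). For k > 1 that view no longer fits, because top_class has k columns per image; in PyTorch one would instead broadcast, for example (top_class == labels.view(-1, 1)).any(dim=1), to get one hit per image.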

The error I get when running the notebook:

AttributeError                            Traceback (most recent call last)
<ipython-input-11-c218f8f2b72e> in <module>()
      8 
      9 for epoch in range(epochs):
---> 10     for images,labels in train_dataloader:
     11         steps += 1
     12         optimizer.zero_grad()

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    312         if self.num_workers == 0:  # same-process loading
    313             indices = next(self.sample_iter)  # may raise StopIteration
--> 314             batch = self.collate_fn([self.dataset[i] for i in indices])
    315             if self.pin_memory:
    316                 batch = pin_memory_batch(batch)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in <listcomp>(.0)
    312         if self.num_workers == 0:  # same-process loading
    313             indices = next(self.sample_iter)  # may raise StopIteration
--> 314             batch = self.collate_fn([self.dataset[i] for i in indices])
    315             if self.pin_memory:
    316                 batch = pin_memory_batch(batch)

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/folder.py in __getitem__(self, index)
     99         """
    100         path, target = self.samples[index]
--> 101         sample = self.loader(path)
    102         if self.transform is not None:
    103             sample = self.transform(sample)

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/folder.py in default_loader(path)
    145         return accimage_loader(path)
    146     else:
--> 147         return pil_loader(path)
    148 
    149 

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/folder.py in pil_loader(path)
    127     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    128     with open(path, 'rb') as f:
--> 129         img = Image.open(f)
    130         return img.convert('RGB')
    131 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in open(fp, mode)
   2319     return True
   2320 
-> 2321 
   2322 def new(mode, size, color=0):
   2323     """

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in preinit()
    368 
    369 
--> 370 def preinit():
    371     """Explicitly load standard file format drivers."""
    372 

/usr/local/lib/python3.6/dist-packages/PIL/PpmImagePlugin.py in <module>()
    156 Image.register_save(PpmImageFile.format, _save)
    157 
--> 158 Image.register_extensions(PpmImageFile.format, [".pbm", ".pgm", ".ppm"])

AttributeError: module 'PIL.Image' has no attribute 'register_extensions'
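Read from the top, the traceback shows that the failure is not in the training code itself. Iterating over the DataLoader (same-process loading, num_workers == 0) calls dataset[i] for each index, and the ImageFolder __getitem__ opens the image file with its loader only at that moment, which ends in PIL. A simplified pure-Python sketch of that lazy path follows; FolderLikeDataset and iterate_batches are hypothetical stand-ins, not the torchvision or PyTorch implementations, and the zip-based batching is a trivial substitute for collate_fn.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple


class FolderLikeDataset:
    """Hypothetical stand-in for an ImageFolder-style dataset.

    Stores (path, target) pairs and calls `loader` only when an item is
    requested, so a broken loader does not fail at construction time.
    """

    def __init__(self, samples: Sequence[Tuple[str, int]],
                 loader: Callable[[str], Any],
                 transform: Optional[Callable[[Any], Any]] = None) -> None:
        self.samples = list(samples)
        self.loader = loader
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        path, target = self.samples[index]
        sample = self.loader(path)  # the file is opened here, on first access
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


def iterate_batches(dataset: FolderLikeDataset,
                    batch_size: int) -> Iterator[Tuple[List[Any], List[int]]]:
    """Yield batches in the same process, loosely mirroring num_workers == 0."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]
        samples, targets = zip(*batch)  # trivial substitute for collate_fn
        yield list(samples), list(targets)
```

Constructing the dataset and the generator succeeds even with a broken loader; the exception only appears on the first batch request, which is why the notebook reports it at the for images,labels in train_dataloader line.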

1 answer:

Answer 0 (score: 0)

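The error is caused by interference from an older version of Pillow preinstalled on Colab. The last frame of the traceback shows the mismatch: PpmImagePlugin.py calls Image.register_extensions, but the PIL.Image module that is loaded does not have that function. You need to upgrade Pillow; the following commands install version 5.3.0: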

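# Remove the preinstalled Pillow and install 5.3.0 in its place
!pip uninstall -y Pillow
!pip install Pillow==5.3.0
# In an already running kernel this returns the cached module, hence the restart below
import PIL.Image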

Now restart the runtime (Runtime > Restart runtime in Colab). This clears the error.
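To confirm that the restart took effect, one can compare the Pillow version imported into the running process with the version installed on disk. The sketch below uses only the standard library (importlib.metadata, which assumes Python 3.8 or newer, unlike the Python 3.6 in the traceback); the helper name check_loaded_vs_installed is hypothetical.

```python
import sys
from importlib import metadata
from typing import Optional, Tuple


def check_loaded_vs_installed(module_name: str,
                              dist_name: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (loaded_version, installed_version) for a package.

    loaded_version comes from the module object already imported into this
    process (None if it has not been imported); installed_version comes from
    the package metadata on disk (None if the distribution is not installed).
    """
    module = sys.modules.get(module_name)
    loaded = None
    if module is not None:
        # Newer Pillow releases expose __version__; older ones exposed PILLOW_VERSION
        loaded = getattr(module, "__version__", None) or getattr(module, "PILLOW_VERSION", None)
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        installed = None
    return loaded, installed
```

In the notebook, after import PIL, calling check_loaded_vs_installed("PIL", "Pillow") should return the same version twice; if the two values differ, the process still holds the old module and the runtime needs to be restarted again.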