Pytorch - torchvision.dataset.ImageFolder的子类 - 导入错误

时间:2018-06-12 13:03:54

标签: python python-import pytorch

在我的上一个post之后,我现在正在尝试实现torchvision.datasets.ImageFolder类的子类。以下代码返回错误("name 'default_loader' is not defined"),我无法弄清楚原因。你能帮帮我吗?

class ExtendingImageFolder(torchvision.datasets.ImageFolder)
   def __init__(self,root,transform=None, target_transform=None,loader=default_loader):
       super().__init__(root,transform,target_transform,loader)

当我删除“None”和“default_loader”时,将其写为这样;

    class ExtendingImageFolder(torchvision.datasets.ImageFolder)
   def __init__(self,root,transform, target_transform,loader):
       super().__init__(root,transform,target_transform,loader)

在尝试创建此类的实例时,我收到错误的输入参数错误,例如:

JJ=ExtendingImageFolder(root='C:/',transform=transform)

我在这里做错了什么?

提前致谢!

1 个答案:

答案 0 :(得分:2)

default_loader()torchvision/datasets/folder.py中定义的函数,ImageFolder和其他基于文件夹的数据集帮助程序。

但不会导出torchvision/datasets/__init__.py(与ImageFolder不同)。您仍然可以使用" from torchvision.datasets.folder import default_loader"直接导入它。 - 这应解决您的导入错误。