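While creating my own dataset to feed into a PyTorch DataLoader, I designed a way to inherit from a class programmatically, essentially extending the class that will be used as the dataset in order to add "custom" functionality to it. The dynamic extension itself works fine. However, PyTorch doesn't like it and complains as soon as I start iterating over a DataLoader built on it.
Here is a toy example of the extended class: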
# Mock dataset. This has to be on a different file for some reason
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second

import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)
    b = extended_class()
    b.hello()  # this works!
    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    first, second = next(iterator)
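
The error message suggests the failure comes from pickle rather than from the dataset logic: with num_workers > 0, the DataLoader may need to pickle the dataset to send it to a worker process (this appears to depend on the multiprocessing start method of the platform). A minimal sketch that seems to reproduce the same behaviour without PyTorch is shown here; `Base` and `try_pickle` are hypothetical names used only for illustration and are not part of the code above.

```python
import pickle


class Base:
    """Stand-in for MyDataset (hypothetical; no torch dependency)."""

    def __init__(self):
        self.first = 1
        self.second = 2


def extend_class(base_class):
    # Same pattern as above: B is defined inside a function, so its
    # qualified name is 'extend_class.<locals>.B'.
    class B(base_class):
        def hello(self):
            return 'Yo!'
    return B


def try_pickle(obj):
    """Return True if obj can be pickled, False if pickle rejects it."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


# Instance of the class created inside extend_class().
local_ok = try_pickle(extend_class(Base)())

# Instance of the class defined at module level, for comparison.
module_ok = try_pickle(Base())

if __name__ == '__main__':
    print('local class picklable:', local_ok)
    print('module-level class picklable:', module_ok)
```

If this diagnosis holds, the problem is that pickle stores classes by reference to an importable name, and a class created inside a function has no such name.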
Any way to solve this problem would be appreciated!