PyTorch扩展数据集:无法腌制本地对象

时间:2020-06-26 12:36:07

标签: python oop object pytorch

在创建自己的数据集以馈入pytorch DataLoader的情况下, 我设计了一种以编程方式从类继承的方法,因此从根本上扩展了将用作数据集的类,以便为其添加“自定义”功能。动态扩展效果很好。但是,PyTorch不喜欢它,当我开始基于它迭代DataLoader时,它会抱怨。

这是扩展类的玩具示例:

# Mock dataset. This has to be on a different file for some reason
from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)

    b = extended_class()
    b.hello() # this works!

    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'

    first, second = next(iterator)

任何解决此问题的方法都值得赞赏!

0 个答案:

没有答案