从 pytorch 数据集返回索引:更改 __getitem__ 的函数导致元类冲突

时间:2021-04-27 08:58:29

标签: python python-3.x pytorch metaclass

我有多个类(用于不同的数据集)从 pytorch 的 Dataset 类继承。它们有一个通用的结构,如下所示:

from torch.utils.data import Dataset

class SomeDataset(Dataset):

    def __init__(self, data, labels):
        super(SomeDataset, self).__init__()
        self.data = data
        self.labels = labels
        self.__name__ = 'SomeDataset'

    def __getitem__(self, index):
        return {'data': self.data[index], 'label': self.labels[index]}

    def __len__(self):
        return len(data)

最近我意识到在批处理时跟踪传递到 Dataloader 的标签是有益的,所以在谷歌搜索如何做到这一点时,我遇到了 this thread,这是我修改代码的地方编写这个函数:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return {'index':1, **dataset_class.__getitem__(self, index)}

    return type(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

我以前从未见过 type 像这样使用,但经过一些谷歌搜索后,它一些有意义,所以我尝试了它。不幸的是,这导致了这个错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases

这导致了更多的谷歌搜索,即使我开始了解元类是什么以及它们是如何使用的,我仍然无法弄清楚这种方法有什么问题或如何解决它-我开始认为将这个功能重写到我的数据集类中可能会更容易,而不是使用一些简洁的包装器来为我做这件事。任何人都可以衡量我缺少的东西吗?

1 个答案:

答案 0 :(得分:0)

就这样做:

def return_indices(dataset_class):
    
    def __getitem__(self, index):
        return {'index':1, **dataset_class.__getitem__(self, index)}
    metacls = type(dataset_class)
    return metacls(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

发生了什么:如您所见,对 type 的 3 参数调用是在 Python 中以编程方式创建新类的方法,无需“class”语句及其主体。

但是 type 是“基本元类”——虽然它的实例将是普通类,但它也会将你正在创建的类的元类“硬编码”到自身——相反,使用 {{1} } 语句将使 Python 在您创建的类的基类中搜索合适的元类。

仅使用您的派生类元类(通过单参数形式的类型获得,如上,或通过类的 class 属性获得,如在 __class__ 中)。

使用 this 作为可调用对象代替 type 会将其自身作为元类,事情应该可以正常工作。

注意:由于有更多的元类机制,例如 dataset_class.__class__,仅调用元类而不是 __prepare__ 并不总是有效 - 正确的通用方法这样做涉及调用 typetypes.prepare_class 并有一个回调来执行在类语句主体中发生的类主体的等效执行。大多数情况下不需要。

相关问题