How can I dump a `sklearn.neighbors.KDTree` object to hdf5?

Time: 2016-07-18 07:22:01

Tags: scikit-learn pickle hdf5 kdtree

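According to the documentation of sklearn.neighbors.KDTree, a KDTree object can be dumped to disk with pickle. However, dumping, loading, and storage are all slow.

Is it possible to dump it in hdf5 format instead?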

1 answer:

Answer 0: (score: 2)

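You can use `__getstate__` and `__setstate__`. Most of the internal quantities are typed arrays or scalars, so they are suitable for hdf5. There is still some work to do, because the last item returned by `__getstate__` is the distance metric, an object rather than an array; it can easily be converted to a byte string for hdf5 storage with `pickle.dumps`.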
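To make the mechanism concrete, here is a minimal sketch that uses only the standard library. `TinyIndex`, `dump_state`, and `load_state` are hypothetical names invented for this illustration, and a plain dict stands in for the HDF5 group; it is not how KDTree itself is implemented.

```python
import pickle


class TinyIndex:
    """Hypothetical stand-in for KDTree: plain values plus one object."""

    def __init__(self, points, leaf_size, metric):
        self.points = points
        self.leaf_size = leaf_size
        self.metric = metric  # plays the role of the distance metric

    def __getstate__(self):
        # Everything needed to rebuild the object, as a flat tuple.
        return (self.points, self.leaf_size, self.metric)

    def __setstate__(self, state):
        self.points, self.leaf_size, self.metric = state


def dump_state(obj, store):
    """Write each state item under its index; pickle only the last one."""
    state = list(obj.__getstate__())
    state[-1] = pickle.dumps(state[-1])
    for i, value in enumerate(state):
        store[str(i)] = value


def load_state(cls, store, n_items):
    """Rebuild an instance without calling __init__."""
    state = [store[str(i)] for i in range(n_items)]
    state[-1] = pickle.loads(state[-1])
    obj = cls.__new__(cls)
    obj.__setstate__(tuple(state))
    return obj


store = {}  # assumption: a dict standing in for an h5py group
original = TinyIndex([[0.0, 1.0], [2.0, 3.0]], 40, {"name": "euclidean"})
dump_state(original, store)
restored = load_state(TinyIndex, store, 3)
```

Only the last item goes through pickle; the rest are stored as they are, which is where the time and space savings would come from when they are large arrays.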

If you find this interesting, you can find the source code of KDTree at the link in the code below, and check the values returned by `__getstate__`.

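from sklearn.neighbors import KDTree
import h5py
import numpy as np
import pickle

"""
You may find the source code of KDTree at the link below:
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/binary_tree.pxi
"""

__all__ = ["KDTreeH5"]


class KDTreeH5(KDTree):
    def dump(self, file):
        """
        file: str or HDF group
        """
        opened = not isinstance(file, h5py.Group)
        if opened:
            # Open explicitly for writing; the default mode differs across h5py versions.
            file = h5py.File(file, 'w')
        try:
            state = list(self.__getstate__())
            # 12 items matches the scikit-learn version this was written against;
            # other versions may return a different number.
            assert len(state) == 12
            # Pickle dist_metric and wrap the bytes in np.void so that h5py stores
            # them as opaque binary data (pickled bytes may contain NUL bytes).
            state[-1] = np.void(pickle.dumps(state[-1]))
            for i, v in enumerate(state):
                file[str(i)] = v
        finally:
            if opened:
                file.close()

    @classmethod
    def load(cls, file):
        """
        file: str or HDF group
        """
        opened = not isinstance(file, h5py.Group)
        if opened:
            file = h5py.File(file, 'r')
        try:
            # Dataset.value was removed in h5py 3.0; [()] reads the whole dataset.
            state = [file[str(i)][()] for i in range(12)]
        finally:
            if opened:
                file.close()
        # Recover dist_metric from the pickled bytes.
        state[-1] = pickle.loads(state[-1].tobytes())

        # Create the instance without running __init__, then restore its state.
        obj = cls.__new__(cls)
        obj.__setstate__(state)
        return obj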
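
One detail worth checking separately is how the pickled metric is stored. Pickled bytes can contain NUL bytes, and h5py may refuse to store such data as a string. A minimal sketch, assuming h5py and NumPy are installed, wraps the blob in `np.void` so it is written as opaque binary data; the file name `state.h5` and the dataset keys are made up for the example.

```python
import os
import pickle
import tempfile

import h5py
import numpy as np

# Stand-ins for three kinds of state items: an array, a scalar, an object.
points = np.arange(6.0).reshape(3, 2)
leaf_size = 40
metric = {"name": "euclidean", "p": 2}  # assumption: any picklable object

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.h5")

    with h5py.File(path, "w") as f:
        f["0"] = points                       # arrays are stored natively
        f["1"] = leaf_size                    # so are scalars
        f["2"] = np.void(pickle.dumps(metric))  # opaque binary blob

    with h5py.File(path, "r") as f:
        points_back = f["0"][()]
        leaf_size_back = int(f["1"][()])
        # The dataset reads back as np.void; tobytes() recovers the pickle.
        metric_back = pickle.loads(f["2"][()].tobytes())
```

The arrays and scalars never pass through pickle here, so only the small metric object pays the serialization cost.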