在python2中取消由python3保存的sklearn.ensemble.RandomForestClassifier

时间:2017-07-03 02:19:20

标签: python numpy scikit-learn pickle

我遇到与unpickle sklearn.tree.DescisionTreeRegressor in python 2 from python3类似的问题,但stackoverflow.com/questions/1385096/上的链接解决方案似乎对我不起作用。

我在尝试使用python2加载时遇到了类似的问题 sklearn我在python3中保存的RandomForestClassifier。

我把它归结为核心问题,即加载一个结构化的numpy数组,保存它以表示树中的节点,这导致 ValueError: non-string names in Numpy dtype unpickling

我创建了一个MWE。

在python 3中

import numpy as np
import pickle
data = np.array(
    [( 1, 26, 69,   5.32214928e+00,  0.69562945, 563,  908.,  1),
     ( 2,  7, 62,   1.74883020e+00,  0.33854101, 483,  780.,  1),
     (-1, -1, -2,  -2.00000000e+00,  0.76420451,   7,    9., -2),
     (-1, -1, -2,  -2.00000000e+00,  0.        ,  62,  106., -2)],
  dtype=[('left_child', '<i8'), ('right_child', '<i8'),
  ('feature', '<i8'), ('threshold', '<f8'), ('impurity',
  '<f8'), ('n_node_samples', '<i8'),
  ('weighted_n_node_samples', '<f8'), ('missing_direction',
  '<i8')])

# Save using pickle
with open('data.pkl', 'wb') as file_:
    # Use protocol 2 to support python2 and 3
    pickle.dump(data, file_, protocol=2)

# Save with numpy directly
np.save('data.npy', data)

然后在python 2中

# Load with pickle
import pickle
with open('data.pkl', 'rb') as file_:
    data = pickle.load(file_)
# This results in `ValueError: non-string names in Numpy dtype unpickling`

# Load with numpy directly
data = np.load('data.npy')
# This works, but wont help me load a pickled sklearn.ensemble.RandomForestClassifier

然而,这仍然不会让sklearn在2到3之间发挥出色。 那么,我们怎样才能让pickle正确加载这个numpy对象呢? 以下是链接中建议的修复:

from lib2to3.fixes.fix_imports import MAPPING
import sys
import pickle

# MAPPING maps Python 2 names to Python 3 names. We want this in reverse.
REVERSE_MAPPING = {}
for key, val in MAPPING.items():
    REVERSE_MAPPING[val] = key

# We can override the Unpickler and loads
class Python_3_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module in REVERSE_MAPPING:
            module = REVERSE_MAPPING[module]
        __import__(module)
        mod = sys.modules[module]
        klass = getattr(mod, name)
        return klass

with open('data.pkl', 'rb') as file_:
    data = Python_3_Unpickler(file_).load()

这仍然不起作用

0 个答案:

没有答案