使用joblib加载的sklearn模型时出错。 TypeError:根据规则“安全”,无法将数组数据从dtype('O')转换为dtype('int64')

时间:2019-03-29 00:53:09

标签: python scikit-learn joblib

我使用sklearn创建了一个 VotingClassifier()对象。稍后,我使用joblib将其保存到 voting_predictor.pkl 文件中。成功加载数据时,当我尝试将某些数据预测为voting_predictor.predict(X_test)时,出现以下错误:

  
    

TypeError:无法根据规则“安全”将数组数据从dtype('O')转换为dtype('int64')

  

我尝试使用 pickle 转储/加载对象,但得到了相同的确切错误。代码如下:

eclf1 = VotingClassifier(estimators=estimators, voting='hard')

eclf1 = eclf1.fit(X_train, y_train)
y_pred = eclf1.predict(X_test)

report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)

print(report)
print(poll_accuracy)

# successful object dump
filename = 'voting_predictor.pkl'
joblib.dump(eclf1, filename)

#successful object load
voting_predictor = joblib.load(filename)
# this prints the object correctly, showing all its parameters 
print(voting_predictor)

#error shows here
y_pred = voting_predictor.predict(X_test)

report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)

print(voting_predictor)成功打印对象及其所有参数。为什么会这样?

2 个答案:

答案 0 :(得分:0)

在将Catbooster与其他预测变量合在一起时,我遇到了相同的错误。 我找到了this解决方案,但是我正在寻找一种更优雅的解决方案。

答案 1 :(得分:0)

问题在于目标列是类的名称,如字符串。似乎保留该字符串值而未将其标签编码为某个整数会导致此错误。但是,在任何其他情况下,sklearn都会正确处理每个类的字符串名称,并提供所有指标(例如,classification_report和precision_score)而不会出错。仅当我从文件中加载对象时,才会发生错误。