我使用sklearn创建了一个 VotingClassifier()对象。稍后,我使用joblib将其保存到 voting_predictor.pkl 文件中。成功加载数据时,当我尝试将某些数据预测为voting_predictor.predict(X_test)
时,出现以下错误:
TypeError:无法根据规则“安全”将数组数据从dtype('O')转换为dtype('int64')
我尝试使用 pickle 转储/加载对象,但得到了相同的确切错误。代码如下:
eclf1 = VotingClassifier(estimators=estimators, voting='hard')
eclf1 = eclf1.fit(X_train, y_train)
y_pred = eclf1.predict(X_test)
report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)
print(report)
print(poll_accuracy)
# successful object dump
filename = 'voting_predictor.pkl'
joblib.dump(eclf1, filename)
#successful object load
voting_predictor = joblib.load(filename)
# this prints the object correctly, showing all its parameters
print(voting_predictor)
#error shows here
y_pred = voting_predictor.predict(X_test)
report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)
print(voting_predictor)
成功打印对象及其所有参数。为什么会这样?
答案 0 :(得分:0)
在将Catbooster与其他预测变量合在一起时,我遇到了相同的错误。 我找到了this解决方案,但是我正在寻找一种更优雅的解决方案。
答案 1 :(得分:0)
问题在于目标列是类的名称,如字符串。似乎保留该字符串值而未将其标签编码为某个整数会导致此错误。但是,在任何其他情况下,sklearn都会正确处理每个类的字符串名称,并提供所有指标(例如,classification_report和precision_score)而不会出错。仅当我从文件中加载对象时,才会发生错误。