如何使用scikitlearn保存一个热编码模型并预测新的未编码数据?

时间:2019-05-22 07:45:01

标签: python machine-learning scikit-learn

我的数据集包含3个分类特征,我使用一种热编码将其更改为二进制格式,并且一切正常。但是,当我想保存经过训练的模型并预测新的原始数据时,输入的内容未按我的预期进行编码,并且会导致错误。

combined_df_raw2= pd.concat([train_x_raw,unknown_test_df])
combined_df2 = pd.get_dummies(combined_df_raw2, columns=nominal_cols, 
drop_first=True)

encoded_unknown_df = combined_df2[len(train_x_raw):]

classifier = DecisionTreeClassifier(random_state=17)
classifier.fit(train_x_raw, train_Y)

pred_y = classifier.predict(encoded_unknown_df)

#here I use joblib to save my model and load it again
joblib.dump(classifier, 'savedmodel')
imported_model = joblib.load('savedmodel')

#here I input unencoded raw data for predict and got error that cannot             
convert 'tcp' to float, means that it is not encoded 

imported_model.predict([0,'tcp','vmnet','REJ',0,0,0,23])   

ValueError:无法将字符串转换为float:'tcp'

3 个答案:

答案 0 :(得分:1)

模型是在对分类变量进行编码后进行训练的,因此,必须在对各个变量进行一次热编码之后才能给出输入。

您可以使用sklearn.preprocessing中的OneHotEncoder,对测试数据进行编码,然后将其提供给模型。

答案 1 :(得分:0)

@chintan然后例如获取即将到来的原始数据,如果您转换仅具有一个实例的类别变量,则它将仅增加一列,而在拥有类别列之前,则好像有500列。因此不会再匹配。 以货币为例,一个实例即将使用INR,即使您进行编码,它也会将其转换为列, 但在您为世界上所有的粗面粉列之前

答案 2 :(得分:0)

使用fit(),然后使用transform(),这样一来,您就可以在安装好一个热编码器后对其进行腌制。

from sklearn.preprocessing import OneHotEncoder
enc = OneHotEncoder(handle_unknown='ignore')
X = [['Male', 1], ['Female', 3], ['Female', 2]]
enc.fit(X)

然后让您腌制编码器,您可以使用其他方式保留编码器。签出https://scikit-learn.org/stable/modules/model_persistence.html

import pickle
with open('encoder.pickle', 'wb') as f:
    pickle.dump(enc, f)

现在,当您有新的数据可以预测时,您必须首先遍历对训练数据所做的整个预处理流程。在这种情况下,编码器。让我们将其加载回去,

with open('encoder.pickle', 'rb') as f:
    enc = pickle.loads(f)

加载后,只需转换新数据即可。

enc.transform(new_data)

要了解有关泡菜的更多信息,https://docs.python.org/3/library/pickle.html