This RandomForestClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this method

Time: 2018-07-18 08:52:35

Tags: python machine-learning scikit-learn cross-validation grid-search

I am trying to train a decision tree model, save it, and then reload it later when I need it. However, I keep getting the following error:

  

This DecisionTreeClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

Here is my code:

X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.20, random_state=4)

names = ["Decision Tree", "Random Forest", "Neural Net"]

classifiers = [
    DecisionTreeClassifier(),
    RandomForestClassifier(),
    MLPClassifier()
    ]

score = 0
for name, clf in zip(names, classifiers):
    if name == "Decision Tree":
        clf = DecisionTreeClassifier(random_state=0)
        grid_search = GridSearchCV(clf, param_grid=param_grid_DT)
        grid_search.fit(X_train, y_train_TF)
        if grid_search.best_score_ > score:
            score = grid_search.best_score_
            best_clf = clf
    elif name == "Random Forest":
        clf = RandomForestClassifier(random_state=0)
        grid_search = GridSearchCV(clf, param_grid_RF)
        grid_search.fit(X_train, y_train_TF)
        if grid_search.best_score_ > score:
            score = grid_search.best_score_
            best_clf = clf

    elif name == "Neural Net":
        clf = MLPClassifier()
        clf.fit(X_train, y_train_TF)
        y_pred = clf.predict(X_test)
        current_score = accuracy_score(y_test_TF, y_pred)
        if current_score > score:
            score = current_score
            best_clf = clf


pkl_filename = "pickle_model.pkl"  
with open(pkl_filename, 'wb') as file:  
    pickle.dump(best_clf, file)

from sklearn.externals import joblib
# Save to file in the current working directory
joblib_file = "joblib_model.pkl"  
joblib.dump(best_clf, joblib_file)

print("best classifier: ", best_clf, " Accuracy= ", score)

This is how I load the model:

#First method
with open(pkl_filename, 'rb') as h:
    loaded_model = pickle.load(h) 
#Second method 
joblib_model = joblib.load(joblib_file)

As you can see, I tried two ways of saving it, but neither worked.

This is how I test it:

print(loaded_model.predict(test)) 
print(joblib_model.predict(test)) 

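You can clearly see that the models are actually fitted, and if I try any other model (such as SVM or logistic regression), the same approach works fine.

1 Answer:

Answer 0: (score: 5)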

The problem is in this line:

best_clf = clf

You passed clf to grid_search, which clones the estimator and fits the data on those clones. So your original clf remains untouched and unfitted.
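The cloning behaviour described above can be checked with a small sketch. The iris dataset, the max_depth grid, and the helper can_predict are assumptions made for illustration; they do not come from the question.

```python
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Stand-in data (assumption): the question's data and labels are not shown.
X, y = load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=0)
param_grid_DT = {"max_depth": [2, 3, 4]}  # hypothetical grid
grid_search = GridSearchCV(clf, param_grid=param_grid_DT, cv=3)
grid_search.fit(X, y)


def can_predict(estimator):
    """Return True if the estimator predicts without raising NotFittedError."""
    try:
        estimator.predict(X[:1])
        return True
    except NotFittedError:
        return False


# The estimator handed to GridSearchCV was only cloned, never fitted itself.
clf_fitted = can_predict(clf)
# With the default refit=True, the best clone is refitted on the full data.
best_fitted = can_predict(grid_search.best_estimator_)
print("clf fitted:", clf_fitted)
print("best_estimator_ fitted:", best_fitted)
```

Here clf raises the same "not fitted yet" error as in the question, while grid_search.best_estimator_ is a separate, fitted object.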

What you need is

best_clf = grid_search

to save the fitted grid_search model.

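If you don't want to save the entire contents of grid_search, you can use the best_estimator_ attribute of grid_search to get the actual cloned, fitted model.

best_clf = grid_search.best_estimator_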
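Putting the fix together, a minimal end-to-end sketch might look like the following. It assumes the iris dataset and small hypothetical parameter grids in place of the question's data, and it pickles to an in-memory bytes object rather than a file; pickle.dump to a file, as in the question, should behave the same way.

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Stand-in data (assumption): replaces the question's data and label.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=4
)

# Hypothetical grids; the question's param_grid_DT / param_grid_RF are not shown.
candidates = {
    "Decision Tree": (DecisionTreeClassifier(random_state=0),
                      {"max_depth": [2, 3, 4]}),
    "Random Forest": (RandomForestClassifier(random_state=0),
                      {"n_estimators": [10, 20]}),
}

score = 0.0
best_clf = None
for name, (clf, param_grid) in candidates.items():
    grid_search = GridSearchCV(clf, param_grid=param_grid, cv=3)
    grid_search.fit(X_train, y_train)
    if grid_search.best_score_ > score:
        score = grid_search.best_score_
        # Keep the fitted clone, not the untouched clf.
        best_clf = grid_search.best_estimator_

# Round-trip through pickle and predict with the reloaded model.
payload = pickle.dumps(best_clf)
loaded_model = pickle.loads(payload)
predictions = loaded_model.predict(X_test)
print("best classifier:", best_clf, "CV score:", score)
```

Separately from the fitting issue, more recent scikit-learn releases no longer provide sklearn.externals.joblib; in those versions, importing the standalone joblib package directly is the usual replacement.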