tensorflow tf.contrib.learn.SVM如何重新加载训练模型并使用predict来对新数据进行分类

时间:2017-07-08 09:35:44

标签: python tensorflow svm tflearn

使用tensorflow tf.contrib.learn.SVM和保存模型训练svm模型;代码

feature_columns = [tf.contrib.layers.real_valued_column(feat) for feat in self.feature_columns]
model_dir = os.path.join(define.root, 'src', 'static_data', 'svm_model_dir')
model = svm.SVM(example_id_column='example_id',
                feature_columns=feature_columns,
                 model_dir=model_dir,
                            config=tf.contrib.learn.RunConfig(save_checkpoints_secs=10))
model.fit(input_fn=lambda: self.input_fun(self.df_train), steps=10000)
results = model.evaluate(input_fn=lambda: self.input_fun(self.df_test), steps=5, metrics=validation_metrics)
for key in sorted(results):
    print('% s: % s' % (key, results[key]))

重新加载训练模型并使用预测对新数据进行分类?

1 个答案:

答案 0 :(得分:0)

培训时

您致电svm.SVM(..., model_dir),然后拨打fit()evaluate()方法。

测试时

您拨打svm.SVM(..., model_dir)然后可以调用predict()方法。 您的模型会在model_dir中找到经过培训的模型,并会加载经过训练的模型参数。

参考

Issue #3340 of TF