如何保存/加载tensorflow contrib.learn回归量?

时间:2017-03-10 21:26:15

标签: python tensorflow save

我有一个tensorflow contrib.learn.DNNRegressor,我已经将其作为以下代码段的一部分进行了培训:

regressor = tf.contrib.learn.DNNRegressor(feature_columns=fc, 
                                          hidden_units=hu_array, 
                                          optimizer=tf.train.AdamOptimizer(
                                                       learning_rate=0.001,
                                                    ),
                                          enable_centered_bias=False,
                                          activation_fn=tf.tanh,
                                          model_dir="./models/my_model/",
                                          )

regressor.fit(x=training_features, y=training_labels, steps=10000)

经过训练的网络运行良好,我想在另一台机器上将其用作其他代码的一部分。我试过复制models / my_model目录,并构建一个新的DNNRegressor,只指向model_dir,但它要求我提供feature_columns和hidden_​​units定义。不应该通过存储在model_dir中的快照获得该信息吗?是否有更好的方法来保存/恢复性能良好的训练模型,以用作预测器,而无需单独保存feature_columns和hidden_​​units?

2 个答案:

答案 0 :(得分:1)

我想出了一些可行的东西 - 不理想,但它完成了工作。如果有人有更好的想法,我会全力以赴。

我将我的kwargs for DNNRegressor转换为dict,并使用了**运算符。然后我能够腌制kwargs dict,并从中重建DNNRegressor。 E.g:

reg_args = {'feature_columns': fc, 'hidden_units': hu_array, ...}
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
pickle.dump(reg_args, open('reg_args.pkl', 'wb'))

稍后,我通过以下方式重建:

reg_args = pickle.load(open('reg_args.pkl', 'rb'))
# On another machine and so my model dir path changed:
reg_args['model_dir'] = NEW_MODEL_DIR
regressor = tf.contrib.learn.DNNRegressor(**reg_args)

效果很好。我确信必须有更好的方法,但是现在如果有人试图找出tf.contrib.learn的解决方法,这是一个解决方案。

答案 1 :(得分:0)

培训时

您致电DNNRegressor(..., model_dir),然后拨打fit()evaluate()方法。

测试时

您拨打DNNRegressor(..., model_dir)然后可以调用predict()方法。 您的模型会在model_dir中找到经过培训的模型,并会加载经过训练的模型参数。

参考

Issue #3340 of TF