我继承了一些TF代码,每个请求执行以下操作:
def predict_tf(ml_wrapper, prediction_row_df):
log.debug('request POST {}'.format(prediction_row_df))
prediction_row_df, _, _ = ml_wrapper._engineer_features(prediction_row_df)
# As method says; delete stuff we don't want and scale and impute if needed
ml_wrapper._delete_unused_values_and_scale_and_impute_missing_values(prediction_row_df, single_row_prediction=True)
features, labels = ml_wrapper._split_features_and_labels(prediction_row_df)
panda_function_for_prediction = tf.estimator.inputs.pandas_input_fn(
features,
labels,
batch_size=ml_wrapper.batch_size,
num_epochs=1,
shuffle=False
)
predictions = ml_wrapper.tf_model.predict(
input_fn=panda_function_for_prediction)
probas = list(predictions)[0]['probabilities']
log.warning('PREDICTED: no:{} yes:{}'.format(probas[0], probas[1]))
return probas
代码似乎有效,但在控制台中我看到的内容如下:
2018-06-06 16:32:46,767 INFO [tensorflow:116] Calling model_fn.
2018-06-06 16:32:50,848 INFO [tensorflow:116] Done calling model_fn.
2018-06-06 16:32:51,082 INFO [tensorflow:116] Graph was finalized.
2018-06-06 16:32:51,083 INFO [tensorflow:116] Restoring parameters from /model_tensorflow/model.ckpt-719
2018-06-06 16:32:51,494 INFO [tensorflow:116] Running local_init_op.
2018-06-06 16:32:51,536 INFO [tensorflow:116] Done running local_init_op.
这个操作似乎每个请求需要4个 - 有没有办法加载模型/估算器一次并预测它?
答案 0 :(得分:0)
如果您使用的是Estimator
,则可以尝试以下一种方法:
https://github.com/marcsto/rl/blob/master/src/fast_predict2.py
这基本上可以防止重新加载图形。