根据请求停止Tensorflow重新加载模型

时间:2018-06-06 14:35:31

标签: python tensorflow tensorflow-serving tensorflow-estimator

我继承了一些TF代码,每个请求执行以下操作:

def predict_tf(ml_wrapper, prediction_row_df):
    log.debug('request POST {}'.format(prediction_row_df))

    prediction_row_df, _, _ = ml_wrapper._engineer_features(prediction_row_df)
    # As method says; delete stuff we don't want and scale and impute if needed
    ml_wrapper._delete_unused_values_and_scale_and_impute_missing_values(prediction_row_df, single_row_prediction=True)
    features, labels = ml_wrapper._split_features_and_labels(prediction_row_df)
    panda_function_for_prediction = tf.estimator.inputs.pandas_input_fn(
        features,
        labels,
        batch_size=ml_wrapper.batch_size,
        num_epochs=1,
        shuffle=False
    )
    predictions = ml_wrapper.tf_model.predict(
        input_fn=panda_function_for_prediction)
    probas = list(predictions)[0]['probabilities']
    log.warning('PREDICTED: no:{} yes:{}'.format(probas[0], probas[1]))
    return probas

代码似乎有效,但在控制台中我看到的内容如下:

2018-06-06 16:32:46,767 INFO  [tensorflow:116] Calling model_fn.
2018-06-06 16:32:50,848 INFO  [tensorflow:116] Done calling model_fn.
2018-06-06 16:32:51,082 INFO  [tensorflow:116] Graph was finalized.
2018-06-06 16:32:51,083 INFO  [tensorflow:116] Restoring parameters from /model_tensorflow/model.ckpt-719
2018-06-06 16:32:51,494 INFO  [tensorflow:116] Running local_init_op.
2018-06-06 16:32:51,536 INFO  [tensorflow:116] Done running local_init_op.

这个操作似乎每个请求需要4个 - 有没有办法加载模型/估算器一次并预测它?

1 个答案:

答案 0 :(得分:0)

如果您使用的是Estimator,则可以尝试以下一种方法: https://github.com/marcsto/rl/blob/master/src/fast_predict2.py 这基本上可以防止重新加载图形。