.pb文件的推理结果与.h5不匹配

时间:2019-06-17 09:33:35

标签: python tensorflow google-cloud-platform

我已经尝试将ai模型部署到GCP的AI平台上,并且看起来很成功,但是与本地h5模型推断结果相比,推断结果不正确。

我现在正在使用tensorflow == 1.2.1,Keras == 2.0.6,Python 3.5.3。

我使用K.set_learning_phase()来区分训练/推理阶段,使用先前的配置/权重重新创建模型,并通过SavedModelBuilder保存了新模型。

def save_model_for_production(model, version, path='prod_models'):
    K.set_learning_phase(0)  # all new operations will be in test mode from now on

    # serialize the model and get its weights, for quick re-building
    config = model.get_config()
    weights = model.get_weights()

    # re-build a model where the learning phase is now hard-coded to 0

    new_model = model.from_config(config)
    # from keras.models import model_from_config
    # new_model = model_from_config(config)
    # new_model = Model.from_config(config)
    new_model.set_weights(weights)


    model_input = tf.saved_model.utils.build_tensor_info(new_model.input)   # deprecated
    model_output = tf.saved_model.utils.build_tensor_info(new_model.output)

    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={tf.saved_model.signature_constants.PREDICT_INPUTS: model_input},
            outputs={tf.saved_model.signature_constants.PREDICT_OUTPUTS: model_output},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    with K.get_session() as sess:
        sess.run(tf.global_variables_initializer())
        #init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        #sess.run(init_op)
        #sess.run(tf.saved_model.main_op.main_op())

        if not os.path.exists(path):
            os.mkdir(path)

        export_path = os.path.join(
            tf.compat.as_bytes(path),
            tf.compat.as_bytes(version))
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        builder.add_meta_graph_and_variables(
            sess=sess, tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict':
                   prediction_signature
            })
        builder.save()

我猜该变量未正确初始化。我尝试了几种tf.global_variables_initializer()tf.local_variables_initializer()等,推断结果互不相同。 (例如[1.0],[0.0],[3.2314555e-13])

但是我没有得到正确的结果。

如果有人知道如何解决此问题,我将不胜感激。

1 个答案:

答案 0 :(得分:0)

似乎已解决此问题。

我没有调用任何初始化程序(例如tf.global_variables_initializer),所以模型可以正常工作。

我无法确切解释原因,但是我猜想在使用K.get_session()进行会话时,不需要初始化变量。