如何保持查找表的初始化以进行预测(而不仅仅是培训)?

时间:2017-05-29 06:43:17

标签: python tensorflow tensorflow-serving

我使用训练数据(作为输入)从tf.contrib.lookup创建一个查找表。然后,我通过该查找表传递每个输入,然后将其传递给我的模型。

这适用于培训,但是当涉及到同一模型的在线预测时,它会引发错误:

  

表未初始化

我正在使用tanh来保存模型。我从这个保存的模型中运行预测。

如何初始化此表以使其保持初始化状态?或者有更好的方法来保存模型,以便始终初始化表吗?

2 个答案:

答案 0 :(得分:9)

我认为您最好使用tf.tables_initializer()作为legacy_init_op

除了表初始化之外,

tf.saved_model.main_op.main_op()还添加了本地和全局初始化操作。 当您加载已保存的模型并运行legacy_init_op时,它会重置您的变量,这不是您想要的。

答案 1 :(得分:5)

您可以指定"初始化"使用main_oplegacy_init_op kwarg将元图添加到包含tf.saved_model.builder.SavedModelBuilder.add_meta_graph的SavedModel包时的操作。如果需要多个操作,您可以使用单个操作,也可以将许多操作与tf.group组合在一起。

请注意,在Cloud ML Engine中,您必须使用legacy_init_op。但是,在将来runtime_version中,您将可以使用main_op (IIRC,从runtime_version == 1.2开始)

saved_model模块提供了一个内置的tf.saved_model.main_op.main_op来封装单个操作中的常见初始化操作(局部变量初始化和表初始化)。

总而言之,代码应该如下所示(改编自this example):

  exporter = tf.saved_model.builder.SavedModelBuilder(
      os.path.join(job_dir, 'export', name))

  # signature_def gets constructed here

  with tf.Session(graph=prediction_graph) as session:
    # Need to be initialized before saved variables are restored
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    # Restore the value of the saved variables
    saver.restore(session, latest)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
        # Relevant change to the linked example is here!
        legacy_init_op=tf.saved_model.main_op.main_op()
    )

注意:如果您使用的是高级库(例如tf.estimator),则这应该是默认值,如果您需要指定其他初始化操作,则可以将它们指定为tf.train.Scaffold的一部分您传递给model_fn中的tf.estimator.EstimatorSpec的对象。