tf.saved_model.builder.SavedModelBuilder()

Time: 2019-03-16 12:48:51

Tags: python tensorflow

I'm using tf.train.ExponentialMovingAverage() to average the weights during training, and I use the averaged weights for prediction. So I built the graph like this:

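import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
    ...
    # `lr` and `loss` are assumed to be defined in the elided code above.
    # An optimizer object is not an op, so minimize() is needed to get train_op.
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
    ema = tf.train.ExponentialMovingAverage(0.999)
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="model")
    # Update the shadow (averaged) copies only after each optimizer step.
    with tf.control_dependencies([train_op]):
        ema_op = ema.apply(model_vars)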
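For reference, the per-step update that tf.train.ExponentialMovingAverage(0.999) applies to each shadow variable can be sketched in plain Python. This is a minimal sketch assuming a constant decay (no num_updates argument, as in the code above) and a shadow that starts at the variable's initial value; the function names ema_update and average_trajectory are hypothetical and not part of TensorFlow.

```python
def ema_update(shadow, value, decay=0.999):
    """One moving-average step: shadow -= (1 - decay) * (shadow - value).

    Written in the equivalent form decay * shadow + (1 - decay) * value.
    """
    return decay * shadow + (1.0 - decay) * value


def average_trajectory(values, decay=0.999):
    """Run ema_update over a sequence of weight values.

    Assumption: values[0] is the variable's initial value (where the shadow
    starts) and values[1:] are the values after each training step.
    """
    shadow = values[0]
    for value in values[1:]:
        shadow = ema_update(shadow, value, decay)
    return shadow


# A weight that jumps from 0.0 to 1.0: with decay=0.5 the shadow goes 0.5, then 0.75.
print(average_trajectory([0.0, 1.0, 1.0], decay=0.5))
```

With decay close to 1 (0.999 here), the shadow moves only 0.1% of the way toward the current weight per step, which is why it smooths out noise from individual updates.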

This works. But when I export the model in the .pb format, something happens that I don't understand. Simplified code:

variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

with tf.Session(config=config) as sess:
    # sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    inputs = {'word_ids': tf.saved_model.utils.build_tensor_info(word_ids)}
    outputs = {'logits': tf.saved_model.utils.build_tensor_info(logits)}

    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    # legacy_init_op = tf.group(tf.global_variables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={"predict_scores": signature})
    # Write the SavedModel (saved_model.pb plus the variables/ directory) to export_path.
    builder.save()

I get this error:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage
    [[Node: model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage/_274 = _Send[T=DT_FLOAT, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_55_model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage", _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
    [[Node: model/model/transitions/ExponentialMovingAverage/_299 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_67_model/model/transitions/ExponentialMovingAverage", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](^save_3/ShardedFilename, ^save_3/SaveV2/tensor_names, ^save_3/SaveV2/shape_and_slices)]]

So I printed the variable names in variables_to_restore, and model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage is in the list.
What is the problem? I did tell the saver to restore model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage from the checkpoint.
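One way to read this error is through a toy model in plain Python (not TensorFlow code). ema.variables_to_restore() returns a dictionary whose keys are checkpoint names (such as the .../ExponentialMovingAverage name above) and whose values, for averaged variables, are the original model variables rather than the shadow variables. A Saver built from that dictionary therefore loads the averaged values into the model weights and never assigns the shadow variables themselves, while the saver that exports all global variables (likely the save_3/SaveV2 nodes in the trace) still expects every one of them to be initialized. ToyVariable, restore, save_all and the variable names below are hypothetical stand-ins.

```python
class ToyVariable:
    """Hypothetical stand-in for a graph variable; value None means 'uninitialized'."""

    def __init__(self, name):
        self.name = name
        self.value = None


def restore(checkpoint, name_map):
    """Like a Saver built from a {checkpoint_name: variable} dict:
    each checkpoint entry is copied into the variable it maps to."""
    for ckpt_name, var in name_map.items():
        var.value = checkpoint[ckpt_name]


def save_all(variables):
    """Like a saver over *all* global variables: each must already have a value."""
    for var in variables:
        if var.value is None:
            raise RuntimeError("Attempting to use uninitialized value " + var.name)
    return {var.name: var.value for var in variables}


# Hypothetical names: one weight plus the shadow variable that ema.apply() creates.
weight = ToyVariable("model/w")
shadow = ToyVariable("model/w/ExponentialMovingAverage")
checkpoint = {"model/w": 1.0, "model/w/ExponentialMovingAverage": 0.9}

# The key is the shadow's checkpoint name, but the value is the *original* weight.
name_map = {"model/w/ExponentialMovingAverage": weight}
restore(checkpoint, name_map)

try:
    save_all([weight, shadow])
except RuntimeError as err:
    print(err)  # Attempting to use uninitialized value model/w/ExponentialMovingAverage
```

Under this reading, the shadow name appearing in variables_to_restore does not mean the shadow variable gets restored; it only says which checkpoint entry feeds the model weight. Giving the shadow variables some value before export (for example the initializer line commented out in the question, run before saver.restore) would plausibly satisfy the export saver, but this is an untested inference from the toy model.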

0 answers