I use `tf.train.ExponentialMovingAverage()` to average the weights during training and then use the averaged weights for prediction. So I built the graph like this:
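```python
import tensorflow as tf  # TensorFlow 1.x API

with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
    ...
    # `loss` comes from the elided model code; minimize() turns the optimizer into a train op
    train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
    ema = tf.train.ExponentialMovingAverage(0.999)
    model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "model")
    # Update the shadow (averaged) copies only after each training step has run
    with tf.control_dependencies([train_op]):
        ema_op = ema.apply(model_vars)
```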
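For reference, the TensorFlow documentation gives the update that `ema_op` applies to each shadow variable as `shadow = decay * shadow + (1 - decay) * variable`. Below is a pure-Python sketch of that rule with the 0.999 decay used above; the constant weight trajectory is made up purely for illustration, and `ema_update` is a hypothetical helper, not a TensorFlow function.

```python
def ema_update(shadow: float, value: float, decay: float = 0.999) -> float:
    """One step of the documented rule: shadow = decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


# Hypothetical trajectory: the weight starts at 0.0, then sits at 1.0 for 1000 steps.
shadow = 0.0  # the shadow variable starts from the weight's initial value
for _ in range(1000):
    shadow = ema_update(shadow, 1.0)

print(shadow)  # equals 1 - 0.999**1000, roughly 0.632
```

With a decay this close to 1, the shadow value lags the raw weight by roughly 1 / (1 - 0.999) = 1000 steps, which is why the averaged weights are smoother than the last raw checkpoint values.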
This works. But when I export the model to the .pb format, something happens that I don't understand. Simplified code:
```python
# Map each EMA name in the checkpoint onto a variable in the rebuilt graph
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session(config=config) as sess:
    # sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    inputs = {'word_ids': tf.saved_model.utils.build_tensor_info(word_ids)}
    outputs = {'logits': tf.saved_model.utils.build_tensor_info(logits)}
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    # legacy_init_op = tf.group(tf.global_variables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={"predict_scores": signature})
    builder.save()  # write saved_model.pb and the variables to export_path
```
I get this error:
```
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage
[[Node: model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage/_274 = _Send[T=DT_FLOAT, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_55_model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage", _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
[[Node: model/model/transitions/ExponentialMovingAverage/_299 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_67_model/model/transitions/ExponentialMovingAverage", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](^save_3/ShardedFilename, ^save_3/SaveV2/tensor_names, ^save_3/SaveV2/shape_and_slices)]]
```
So I printed the variable names in `variables_to_restore`, and `model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage` is in the list.

What is the problem? I did tell the saver to restore `model/model/chars/CNN-char/w_cnn_0/ExponentialMovingAverage` from the checkpoint.
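One way to make the question concrete is a toy pure-Python model of the restore step. It assumes the mapping shape that the `ExponentialMovingAverage.variables_to_restore()` documentation describes: the EMA name is the dictionary *key*, and the *original* variable is the value. `Var`, `restore`, and the checkpoint numbers are hypothetical stand-ins, not TensorFlow APIs. Under that assumption the saver writes the averaged checkpoint value into `w_cnn_0` itself, while the `.../ExponentialMovingAverage` variable that `ema.apply()` created in the rebuilt graph is never a restore target. It would then presumably still be uninitialized when the builder's own saver (the `save_3/SaveV2` nodes in the trace) tries to write every variable out.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Var:
    """Toy stand-in for a graph variable (hypothetical, not a TensorFlow class)."""
    name: str
    value: Optional[float] = None  # None models "uninitialized"


def restore(ckpt: Dict[str, float], restore_map: Dict[str, Var]) -> None:
    """Copy each checkpoint entry into whichever variable object its key maps to."""
    for ckpt_key, var in restore_map.items():
        var.value = ckpt[ckpt_key]


w_name = "model/model/chars/CNN-char/w_cnn_0"
ema_name = w_name + "/ExponentialMovingAverage"

# Values saved during training (made-up numbers for illustration).
checkpoint = {w_name: 0.5, ema_name: 0.42}

# The export graph was rebuilt with ema.apply(), so it holds BOTH variables.
w = Var(w_name)
shadow = Var(ema_name)

# Assumed shape of ema.variables_to_restore(): EMA name -> ORIGINAL variable.
restore_map = {ema_name: w}
restore(checkpoint, restore_map)

print(w.value)       # 0.42: the averaged value was loaded into w_cnn_0
print(shadow.value)  # None: the shadow variable itself was never a restore target
```

In this model, the name being "in the list" only means the checkpoint entry is read; which graph variable receives it is decided by the dictionary value.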