tf.saved_model.builder.SavedModelBuilder.save runs forever when the graph contains a large variable

Asked: 2017-07-24 23:30:23

Tags: tensorflow

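I am trying to export my model with tf.saved_model.builder.SavedModelBuilder.save. The model has a word_embedding variable. In the test version the embedding comes from a 100k file, and saving works fine. After replacing it with the full version (GloVe 300d, about 1 GB), the save call never finishes.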
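
As a rough size check, the dense matrix behind such an embedding can be estimated as vocabulary size x dimensions x bytes per value. The sketch below assumes float32 values and a hypothetical vocabulary of 400,000 words. The question gives only the dimensionality (300) and the file size (about 1 GB), so the vocabulary figure is an assumption, not something stated in the post.

```python
def embedding_nbytes(vocab_size, dim, bytes_per_value=4):
    """Bytes needed for a dense vocab_size x dim matrix of fixed-width values."""
    return vocab_size * dim * bytes_per_value


# Hypothetical vocabulary size; the question does not state it.
ASSUMED_VOCAB_SIZE = 400000
EMBEDDING_DIM = 300  # "300d" from the question

size = embedding_nbytes(ASSUMED_VOCAB_SIZE, EMBEDDING_DIM)
print("approx. %.0f MiB as float32" % (size / 1024.0 ** 2))
```

Under these assumptions the matrix is several hundred megabytes in memory, which is large enough that how it gets serialized during export matters.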

import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils

def main():
  # Created before the graph is imported, so it covers no variables yet;
  # the values come from saver.restore() below.
  init_op = tf.global_variables_initializer()
  with tf.Session() as sess:
    with tf.device('/device:CPU:0'):
      saver = tf.train.import_meta_graph('c:\\myModel\\model.weights\\.meta', clear_devices=True)
      sess.run(init_op)
      saver.restore(sess, "c:\\myModel\\model.weights\\")

      # inputs
      word_ids = sess.graph.get_tensor_by_name('word_ids:0')
      sequence_lengths = sess.graph.get_tensor_by_name('sequence_lengths:0')
      feature_vecs = sess.graph.get_tensor_by_name('feature_vecs:0')
      # outputs
      logits = sess.graph.get_tensor_by_name('proj/Reshape_1:0')
      transitions = sess.graph.get_tensor_by_name('transitions:0')

      word_ids_info = utils.build_tensor_info(word_ids)
      sequence_lengths_info = utils.build_tensor_info(sequence_lengths)
      # Fixed: the original built this from sequence_lengths (copy-paste slip).
      feature_vecs_info = utils.build_tensor_info(feature_vecs)
      # Built from word_ids in the original and not used in the signature below.
      dropout_info = utils.build_tensor_info(word_ids)
      logits_info = utils.build_tensor_info(logits)
      transitions_info = utils.build_tensor_info(transitions)

      prediction_signature = signature_def_utils.build_signature_def(
        inputs={'word_ids': word_ids_info,
                'sequence_lengths': sequence_lengths_info,
                'feature_vecs': feature_vecs_info},
        outputs={'logits': logits_info, 'transitions': transitions_info},
        method_name=signature_constants.PREDICT_METHOD_NAME)

      legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

      builder = saved_model_builder.SavedModelBuilder('c:\\myModel\\exported')

      builder.add_meta_graph_and_variables(
        sess,
        [tag_constants.SERVING],
        signature_def_map={
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
        },
        legacy_init_op=legacy_init_op)

      # This is the call that never returns with the full embedding.
      builder.save(as_text=True)

if __name__ == '__main__':
  main()
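
A possible factor, not confirmed anywhere in this thread: the export code passes as_text=True, which writes the graph in protobuf text format. If the embedding values are stored in the graph itself (for example as the constant initializer of word_embedding, which is an assumption about how the model was built), their raw bytes would have to be written out as an escaped string. The sketch below uses only the Python standard library to estimate how much larger such an escaped dump can get. The escaping rule in text_escaped_len is a simplified model chosen for illustration, not TensorFlow's or protobuf's actual code.

```python
import struct

# Bytes assumed to get a two-character backslash escape: \n \r \t " ' \
_TWO_CHAR = frozenset(b"\n\r\t\"'\\")


def text_escaped_len(data):
    """Length of `data` once written as an escaped string literal.

    Simplified model: printable ASCII stays one character, a handful of
    bytes take a two-character escape, and every other byte is assumed to
    become a four-character octal escape.
    """
    total = 0
    for byte in data:
        if byte in _TWO_CHAR:
            total += 2
        elif 32 <= byte < 127:
            total += 1
        else:
            total += 4
    return total


# A small deterministic stand-in for embedding values.
values = [i / 1000.0 - 0.5 for i in range(1000)]
raw = struct.pack("<%df" % len(values), *values)  # little-endian float32
escaped = text_escaped_len(raw)

print("binary bytes :", len(raw))
print("escaped chars:", escaped)
print("growth factor: %.2f" % (escaped / float(len(raw))))
```

If this is what is happening here, one thing to try would be calling builder.save() without as_text=True, since binary output is the default for that method.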

0 answers:

No answers yet.