我正在尝试导出我的模型,使用的是 tf.saved_model.builder.SavedModelBuilder 的 save 方法。我的模型中有一个 word_embedding 变量:在测试版本中,它对应的是一个 100k 的小文件,保存一切正常;换成完整版的词向量(GloVe 300d,约 1GB)之后,保存过程就一直运行,始终无法结束。
def main():
    """Restore a TF1 checkpoint and export it as a SavedModel.

    Imports the meta graph, restores the checkpoint weights into a fresh
    session, looks up the input/output tensors by name, and writes a
    SavedModel tagged for serving with a single predict signature:

    * inputs:  ``word_ids``, ``sequence_lengths``, ``feature_vecs``
    * outputs: ``logits``, ``transitions``

    All paths are hard-coded; the function takes no arguments and returns
    nothing.
    """
    # Created before the meta graph is imported, so it only covers variables
    # that already exist in the default graph at this point. The restored
    # variables get their values from saver.restore() below.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        with tf.device('device:CPU:0'):
            saver = tf.train.import_meta_graph(
                'c:\\myModel\\model.weights\\.meta', clear_devices=True)
            sess.run(init_op)
            saver.restore(sess, "c:\\myModel\\model.weights\\")

            # inputs
            word_ids = sess.graph.get_tensor_by_name('word_ids:0')
            sequence_lengths = sess.graph.get_tensor_by_name(
                'sequence_lengths:0')
            feature_vecs = sess.graph.get_tensor_by_name('feature_vecs:0')

            # outputs
            logits = sess.graph.get_tensor_by_name('proj/Reshape_1:0')
            transitions = sess.graph.get_tensor_by_name('transitions:0')

            word_ids_info = utils.build_tensor_info(word_ids)
            sequence_lengths_info = utils.build_tensor_info(sequence_lengths)
            # Fixed: this used to be built from `sequence_lengths`, which
            # made the exported 'feature_vecs' input point at the wrong
            # tensor.
            feature_vecs_info = utils.build_tensor_info(feature_vecs)
            logits_info = utils.build_tensor_info(logits)
            transitions_info = utils.build_tensor_info(transitions)

            prediction_signature = signature_def_utils.build_signature_def(
                inputs={
                    'word_ids': word_ids_info,
                    'sequence_lengths': sequence_lengths_info,
                    'feature_vecs': feature_vecs_info,
                },
                outputs={
                    'logits': logits_info,
                    'transitions': transitions_info,
                },
                method_name=signature_constants.PREDICT_METHOD_NAME)

            legacy_init_op = tf.group(
                tf.tables_initializer(), name='legacy_init_op')

            builder = saved_model_builder.SavedModelBuilder(
                'c:\\myModel\\exported')
            builder.add_meta_graph_and_variables(
                sess,
                [tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        prediction_signature,
                },
                legacy_init_op=legacy_init_op)

            # NOTE(review): as_text=True serializes the graph as a text
            # protobuf. If the graph embeds large constants (for example the
            # embedding matrix as a variable initializer), this can be
            # extremely slow -- consider as_text=False. Left unchanged here
            # because it alters the format of the written file; confirm with
            # consumers of the export first.
            builder.save(as_text=True)