Problem saving a retrained TensorFlow model

Date: 2018-02-18 22:11:43

Tags: python tensorflow machine-learning

I am trying to load a previously saved model and save it again after retraining. Loading works fine, but saving fails as follows:

sess=tf.Session()
sess.run(init)
loader = tf.train.import_meta_graph(self.model_path+'.meta')
loader.restore(sess,self.model_path)#tf.train.latest_checkpoint('./'))            
print('Model restored')
#retrain
saver=tf.train.Saver()
saver.save(sess, self.model_path)

When I saved the model for the first time, with the code below, I had no such problem:

saver=tf.train.Saver()
sess=tf.Session()
sess.run(init)
#train
saver.save(sess, self.model_path)

The error I get is:

  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
    self.build()
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build
    restore_sequentially=self._restore_sequentially)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 673, in build
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 557, in _ValidateAndSliceInputs
    names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 535, in OpListToDict
    name)
ValueError: At least two variables have the same name: Variable_15/Adam

1 answer:

Answer 0 (score: 0)

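You see this message because the graph now contains two variables with the same name. tf.train.import_meta_graph reads the graph from the file and adds all of its operations and tensors to the current default graph, on top of whatever is already there. In the question, init already exists before the import, which suggests the model had already been built in that graph; every variable, including Adam's slot variables such as Variable_15/Adam, is then registered twice, and tf.train.Saver() fails while building. I'm surprised import_meta_graph didn't raise such an exception in the first place.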

Here is a complete example that reproduces this behavior:

import os

import tensorflow as tf

# tiny graph
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
cost = tf.reduce_sum(x * output)
# first creation of beta1_power:0 and beta2_power:0
train_op = tf.train.AdamOptimizer().minimize(cost)

# Saver.save() needs the parent directory to exist
if not os.path.isdir('./adam'):
    os.makedirs('./adam')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, './adam/my_model')

    print([v.name for v in tf.global_variables()])

    # second creation of beta1_power:0 and beta2_power:0, in the same graph
    meta_graph = tf.train.import_meta_graph('./adam/my_model.meta')
    meta_graph.restore(sess, './adam/my_model')

    print([v.name for v in tf.global_variables()])

    # ValueError raised by the Saver constructor: beta1_power:0, beta2_power:0, ... now exist twice
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, './adam/my_model2')

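The solution is one of:

  • call tf.reset_default_graph() to clear the graph before calling tf.train.import_meta_graph, and use a new session
  • or skip import_meta_graph and just use tf.train.Saver().restore(sess, '/tmp/model/my_model') to load the weights into the graph you have already built in Python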
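Both options can be sketched as follows. This is a minimal example, not the asker's model: it assumes TensorFlow 1.15 or 2.x so that tensorflow.compat.v1 is available (on TF 1.2, as in the traceback, plain import tensorflow as tf without disable_v2_behavior() should behave the same), and the tiny model, the variable name w, the 'train_op' collection key and the helper functions build_model, train_and_save, retrain_from_meta_graph and retrain_with_python_graph are all hypothetical.

```python
import os
import tempfile

import numpy as np
# Assumption: TF 1.15 or 2.x. On TF 1.x, `import tensorflow as tf` and drop disable_v2_behavior().
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

FEED = np.ones((1, 2), dtype=np.float32)  # hypothetical training batch


def build_model():
    """Hypothetical tiny model; returns the input placeholder and the training op."""
    x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
    w = tf.Variable(tf.ones([2, 1]), name='w')
    cost = tf.reduce_sum(tf.matmul(x, w))
    train_op = tf.train.AdamOptimizer().minimize(cost)
    # Keep a handle in a collection so the op can be found again after import_meta_graph.
    tf.add_to_collection('train_op', train_op)
    return x, train_op


def train_and_save(model_path):
    """First run: build the graph in Python, train one step, save."""
    tf.reset_default_graph()
    x, train_op = build_model()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op, feed_dict={x: FEED})
        tf.train.Saver().save(sess, model_path)


def retrain_from_meta_graph(model_path, new_path):
    """Fix 1: clear the default graph, import the meta graph, restore, retrain, save."""
    tf.reset_default_graph()  # nothing from earlier runs is left to collide with
    loader = tf.train.import_meta_graph(model_path + '.meta')
    with tf.Session() as sess:  # a new session bound to the fresh graph
        loader.restore(sess, model_path)  # sets every variable, so no initializer is run
        x = tf.get_default_graph().get_tensor_by_name('input:0')
        train_op = tf.get_collection('train_op')[0]
        sess.run(train_op, feed_dict={x: FEED})
        names = [v.name for v in tf.global_variables()]
        tf.train.Saver().save(sess, new_path)  # names are unique, so no ValueError
    return names


def retrain_with_python_graph(model_path, new_path):
    """Fix 2: rebuild the same graph in Python and only load the weights."""
    tf.reset_default_graph()
    x, train_op = build_model()  # same code, so variable names match the checkpoint
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        sess.run(train_op, feed_dict={x: FEED})
        saver.save(sess, new_path)


model_dir = tempfile.mkdtemp()  # Saver.save() needs an existing parent directory
base = os.path.join(model_dir, 'my_model')
train_and_save(base)
print(retrain_from_meta_graph(base, base + '_retrained'))
retrain_with_python_graph(base, base + '_retrained2')
```

The first fix looks tensors up by name and the training op through a collection, because Python handles from an earlier graph do not survive tf.reset_default_graph(). The second fix only works if the rebuilt graph creates variables with exactly the names stored in the checkpoint.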