Python TensorFlow:如何使用优化器和import_meta_graph重新启动训练?

时间:2017-04-06 00:03:18

标签: tensorflow

我正试图通过拾取它停止的地方重新启动TensorFlow中的模型训练。我想使用最近添加的(0.12 +我认为)ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used,以便不重建图形。

我见过这方面的解决方案,例如Tensorflow: How to save/restore a model?,但是我遇到了AdamOptimizer的问题,特别是我遇到了import_meta_graph()错误。 This can be fixed by initializing,但我的模型值已被清除!

还有其他答案和一些完整的示例,但它们似乎总是较旧,因此不包括较新的{{1}}方法,或者没有非张量优化器。我能找到的最接近的问题是tensorflow: saving and restoring session,但没有最终的明确解决方案,而且这个例子非常复杂。

理想情况下,我想要一个简单的可运行示例,从头开始,停止,然后重新开始。我有一些有用的东西(下图),但也想知道我是否遗漏了什么。当然,我不是唯一这样做的人吗?

4 个答案:

答案 0 :(得分:8)

以下是我从阅读文档,其他类似解决方案以及反复试验中得出的结果。它是随机数据的简单自动编码器。如果跑,然后再跑,它将从它停止的地方继续(即首次运行的成本函数从~0.5 - > 0.3秒开始运行~0.3)。除非我遗漏了一些东西,所有的存储,构造函数,模型构建,add_to_collection都需要并且按照精确的顺序,但可能有一种更简单的方法。

是的,在这里不需要加载import_meta_graph图表,因为代码就在上面,但是我想要的是实际的应用程序。

from __future__ import print_function
import tensorflow as tf
import os
import math
import numpy as np

output_dir = "/root/Data/temp"
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")

input_length = 10
encoded_length = 3
learning_rate = 0.001
n_epochs = 10
n_batches = 10
if not os.path.exists(model_checkpoint_file_base + ".meta"):
    print("Making new")
    brand_new = True

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in")
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length],
                                          -1.0 / math.sqrt(input_length),
                                          1.0 / math.sqrt(input_length)), name="W_enc")
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc")
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded")
    W_dec = tf.transpose(W_enc, name="W_dec")
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec")
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded")
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")

    saver = tf.train.Saver()
else:
    print("Reloading existing")
    brand_new = False
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")
    g = tf.get_default_graph()
    x_in = g.get_tensor_by_name("x_in:0")
    cost = g.get_tensor_by_name("cost:0")


sess = tf.Session()
if brand_new:
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    init = tf.global_variables_initializer()
    sess.run(init)
    tf.add_to_collection("optimizer", optimizer)
else:
    saver.restore(sess, model_checkpoint_file_base)
    optimizer = tf.get_collection("optimizer")[0]

for epoch_i in range(n_epochs):
    for batch in range(n_batches):
        batch = np.random.rand(50, input_length)
        _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch})
        print("batch_cost:", curr_cost)
        save_path = tf.train.Saver().save(sess, model_checkpoint_file_base)

答案 1 :(得分:2)

我遇到了同样的问题,我只知道出了什么问题,至少在我的代码中是这样。

最后,我在saver.restore()中使用了错误的文件名。必须为此函数指定不带文件扩展名的文件名,就像saver.save()函数:

一样
saver.restore(sess, 'model-1')

而不是

saver.restore(sess, 'model-1.data-00000-of-00001')

有了这个,我就完全按照自己的意愿行事:从头开始,停止,然后重新开始。我不需要使用tf.train.import_meta_graph()函数从元文件初始化第二个保护程序,并且在初始化优化程序后我不需要显式地声明tf.initialize_all_variables()

我的完整模型恢复如下所示:

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'model-1')

我认为在协议V1中你仍然需要将.ckpt添加到文件名中,对于import_meta_graph(),你仍然需要添加.meta,这可能会导致用户之间的混淆。也许这应该在文档中更明确地指出。

答案 2 :(得分:1)

在恢复会话中创建保护程序对象时可能会出现问题。

在恢复会话中使用以下代码时,我获得了与您相同的错误。

lista = [
    {'flag': True, 'value': 0}, {'flag': True, 'value': 5}, {'flag': True, 'value': 10}, 
    {'flag': False, 'value': 15}, {'flag': False, 'value': 20}, {'flag': False, 'value': 25}, 
    {'flag': False, 'value': 30}, {'flag': False, 'value': 35}, {'flag': False, 'value': 40}, 
    {'flag': True, 'value': 45}, {'flag': True, 'value': 50}, {'flag': True, 'value': 55}, 
    {'flag': True, 'value': 60}, {'flag': False, 'value': 65}, {'flag': False, 'value': 70}, 
    {'flag': False, 'value': 75}, {'flag': False, 'value': 80}, {'flag': False, 'value': 85}, 
    {'flag': True, 'value': 90}, {'flag': True, 'value': 95}
]
change = False
output = []
p=[0]*2
flagPast = lista[0]['flag']
for index,item in enumerate(lista):
    if(item['flag'] != flagPast):
        cp=p[:]       
        output.append(zip(['flag','values'],[flagPast,cp]))
        p[0]=item['value']
        change = True
    else:
        change = False
    if(not(change)):
        p[1]=item['value']
    flagPast = item['flag']
    if(index==(len(lista)-1)):
        p[1]=item['value']        
        cp=p[:]       
        output.append(zip(['flag','values'],[flagPast,cp])) 

但是当我以这种方式改变时,

saver = tf.train.import_meta_graph('tmp/hsmodel.meta')
saver.restore(sess, tf.train.latest_checkpoint('tmp/'))

错误消失了。 &#34; tmp / hsmodel&#34;是我在保存会话中给saver.save(sess,&#34; tmp / hsmodel&#34;)的路径。

有关存储和恢复训练MNIST网络(包含Adam优化器)会话的简单示例在此处。这有助于我与我的代码进行比较并解决问题。

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

答案 3 :(得分:0)

通过saver类,我们可以通过以下方式保存会话: saver.save(sess,“ checkpoints.ckpt”)

并允许我们恢复会话: saver.restore(sess,tf.train.latest_checkpoint(“ checkpoints.ckpt”))