TensorFlow检查点保存并读取

时间:2015-12-05 23:32:11

标签: python io tensorflow

我有一个基于TensorFlow的神经网络和一组变量。

训练功能如下:

def train(load = True, step)
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initalizing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'

我正在调用这样的训练功能:

# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)

我进行了这种培训,因为我需要将不同的数据集提供给我的模型。但是,如果我以这种方式调用train函数,TensorFlow将生成错误消息,指示它无法从文件中读取已保存的模型。

经过一些实验后,我发现这是因为检查点保存缓慢。在将文件写入磁盘之前,下一个列车功能将开始读取,从而产生错误。

我曾尝试使用time.sleep()函数在每次调用之间做一些延迟,但它没有工作。

任何人都知道如何解决这种写/读错误?非常感谢你!

1 个答案:

答案 0 :(得分:6)

您的代码中存在一个微妙的问题:每次调用train()函数时,对于所有模型变量和神经网络的其余部分,更多节点将添加到同一TensorFlow图中。这意味着每次构造tf.train.Saver()时,它都包含先前调用train()的所有变量。每次重新创建模型时,都会使用额外的_N后缀创建变量,以便为它们提供唯一的名称:

  1. 使用变量var_avar_b构建保护程序。
  2. 使用变量var_avar_bvar_a_1var_b_1构建保护程序。
  3. 使用变量var_avar_bvar_a_1var_b_1var_a_2var_b_2构建保护程序。
  4. tf.train.Saver的默认行为是将每个变量与相应op的名称相关联。这意味着var_a_1不会从var_a初始化,因为它们最终会有不同的名称。

    解决方案是每次调用train()时创建一个新图表。解决此问题的最简单方法是更改​​主程序,为每次调用train()创建一个新图表,如下所示:

    # First train
    with tf.Graph().as_default():
        train(False, 1)
    
    # Following train
    for i in xrange(10):
        with tf.Graph().as_default():
            train(True, 10)
    

    ...或者,等效地,您可以在with函数内移动train()块。