不使用主管时,Tensorflow会冻结

时间:2017-04-27 23:18:53

标签: python tensorflow

没有GPU,没有队列,Tensorflow 1.1.0

这是一个LSTM代码示例:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

此代码有效。它打印出训练过程信息,很酷。现在,我尝试使用freeze_graph()将经过训练的模型图写入磁盘,最后我发现此LSTM教程使用Supervisor来训练模型,并Supervisor冻结图形,并且freeze_graph()程序中不能使用冻结图。

我尝试从Supervisor切换到使用普通会话。所做的更改是在main()过程中(除了导入一些内容)。它现在看起来像这样(更改的部分突出显示,我删除了所有图形保存相关的东西,这不是问题):

with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(
            -config.init_scale, config.init_scale)
        with tf.name_scope("Train"):
            train_input = PTBInput(
                config=config, data=train_data, name="TrainInput")
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                m = PTBModel(
                    is_training=True, config=config, input_=train_input)
            tf.summary.scalar("Training Loss", m.cost)
            tf.summary.scalar("Learning Rate", m.lr)
        with session.Session() as sess:  # CHANGED
            sess.run(variables.global_variables_initializer())  # CHANGED
            for i in range(config.max_max_epoch):
                lr_decay = config.lr_decay ** max(i +
                                                  1 - config.max_epoch, 0.0)
                m.assign_lr(sess, config.learning_rate * lr_decay)
                print("Epoch: %d Learning rate: %.3f" %
                      (i + 1, sess.run(m.lr)))
                train_perplexity = run_epoch(sess, m, eval_op=m.train_op,
                                             verbose=True)
                print("Epoch: %d Train Perplexity: %.3f" %
                      (i + 1, train_perplexity))

在这些变化之后,整个事情开始冻结在这一行:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py#L300

这是模型内部的session.run()调用(对Ctrl + C没有反应,对kill -9可以反击):

vals = session.run(fetches, feed_dict)

以前的session.run()来电(有一些)工作得很好。

我做错了什么?似乎所有变量都被初始化得很好(这是由原始代码中的Supervisor完成的)。有什么想法吗?

1 个答案:

答案 0 :(得分:2)

当您使用tf.train.Supervisor时,框架代码会在会话开始时自动调用tf.train.start_queue_runners(sess)(以及初始化变量)。如果切换回使用原始tf.Session,则必须手动调用此方法以启动输入管道。如下所示的更改应该有效:

# ...
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.train.start_queue_runners(sess)
  # ...