Deadlock in TensorFlow's MonitoredTrainingSession with a slice input producer

Date: 2017-06-02 15:41:38

Tags: python tensorflow deadlock

The following code deadlocks:

import tensorflow as tf

def train():
    """Stripped down and modified from cifar10.cifar10_train.train"""
    global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
    images = tf.constant([[1, 2, 3], [1, 2, 3]])
    labels = tf.constant([[1, 2, 3], [1, 2, 3]])
    images, labels = tf.train.slice_input_producer([images, labels],
                                                   shuffle=False)
    # input_var = tf.Variable([0, 0, 0])
    # images = input_var.assign(images) # TODO placeholder would work ?
    # input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
    input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
    tf_print = tf.Print(input_batch, [input_batch])
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(tf_print)

if __name__ == '__main__':
    train()

But if I comment out input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77]) and uncomment the commented-out lines, the program keeps printing:

    I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]

  • Why the deadlock? Is using an extra variable, as I did, the way to work around it? Or should I somehow use a placeholder instead?
  • What am I missing that keeps it from terminating after 3 steps?

1 Answer:

Answer 0 (score: 1)

  1. I'm not sure about your first question, but I believe that when you create the MonitoredTrainingSession, it tries to initialize the variables in your graph. In your case, however, the initial value of one of the variables depends on a dequeue operation hidden behind tf.train.slice_input_producer. Since the queue runners have not been started yet, the code deadlocks waiting for something to be enqueued. In your commented-out implementation the init_op does run, so the queue runners can start and your code works.

  2. Here is the explanation for your second question. StopAtStepHook relies on the global_step tensor being updated, which is not the case in your script. The line tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step, 1)) would work: it groups the tf.Print op together with the global_step increment, so global_step is incremented every time tf_print is run.

    import tensorflow as tf
    
    def train():
        """Stripped down and modified from cifar10.cifar10_train.train"""
        global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
        images = tf.constant([[1, 2, 3], [1, 2, 3]])
        labels = tf.constant([[1, 2, 3], [1, 2, 3]])
        images, labels = tf.train.slice_input_producer([images, labels], shuffle=False)
        input_var = tf.Variable([0, 0, 0])
        images = input_var.assign(images) # TODO placeholder would work ?
        input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
        tf_print = tf.group(tf.Print(input_batch, [input_batch]),
                            tf.assign_add(global_step, 1))
        with tf.train.MonitoredTrainingSession(
                hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(tf_print)
    
    if __name__ == '__main__':
        train()
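The initialization-order trap behind the deadlock in point 1 can be mimicked outside TensorFlow with a plain Python queue: a consumer that needs a value during "initialization" blocks if the producer thread (the analogue of the queue runners) has not been started yet. This is only an illustrative sketch of the ordering problem, not TensorFlow code; the names are made up for the example, and a timeout stands in for the real hang:

```python
import queue
import threading

# The queue plays the role of the slice_input_producer's internal queue.
q = queue.Queue()

def init_variable_without_runner():
    """Try to 'initialize a variable' from the queue.

    If no producer thread is running, nothing is ever enqueued and the
    get() blocks -- here we time out instead of hanging forever.
    """
    try:
        return q.get(timeout=0.5)
    except queue.Empty:
        return None  # stands in for the deadlocked session init

# No producer started yet: initialization blocks (times out -> None),
# just like MonitoredTrainingSession's init deadlocks before the
# queue runners are launched.
assert init_variable_without_runner() is None

# Start the producer first (analogous to the queue runners being
# started), and the same initialization succeeds.
threading.Thread(target=lambda: q.put([1, 2, 3]), daemon=True).start()
print(init_variable_without_runner())  # prints [1, 2, 3]
```

The fix in the answer works for the same reason: with the plain init_op, variable initialization no longer waits on the queue, so the session can start the runners afterwards.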