The following code deadlocks:
import tensorflow as tf

def train():
    """Stripped down and modified from cifar10.cifar10_train.train"""
    global_step = tf.contrib.framework.get_or_create_global_step()  # for StopAtStepHook
    images = tf.constant([[1, 2, 3], [1, 2, 3]])
    labels = tf.constant([[1, 2, 3], [1, 2, 3]])
    images, labels = tf.train.slice_input_producer([images, labels],
                                                   shuffle=False)
    # input_var = tf.Variable([0, 0, 0])
    # images = input_var.assign(images)  # TODO placeholder would work ?
    # input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
    input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
    tf_print = tf.Print(input_batch, [input_batch])
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(tf_print)

if __name__ == '__main__':
    train()
But if I comment out input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77]) and uncomment the commented-out lines instead, the program keeps printing:
I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]
Answer (score: 1)
I'm not sure about your first question, but I believe that when you create the MonitoredTrainingSession it tries to initialize the variables in your graph. In your case, however, the initial value of one of those variables depends on a dequeue operation hidden behind tf.train.slice_input_producer. Since the queue runners have not been started yet, the code deadlocks waiting for the queue to be filled. In your commented-out implementation the init_op does run, so the queue runners can start and your code works.
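To make that failure mode concrete, here is a minimal sketch (the setup and variable names are mine, mirroring the question's graph) contrasting an initializer that depends on the queue with one that does not:

import tensorflow as tf

# Assumed setup, as in the question: a slice dequeued from
# tf.train.slice_input_producer.
data = tf.constant([[1, 2, 3], [1, 2, 3]])
sliced = tf.train.slice_input_producer([data], shuffle=False)[0]

# Deadlocks: the variable's initial value is the dequeued tensor, but
# MonitoredTrainingSession runs initializers before the queue runners are
# started, so initialization blocks on an empty queue.
var_from_queue = tf.Variable(sliced)

# Works: the initializer is a plain constant, independent of any queue;
# the dequeued slice is only copied into the variable at run time.
var_static = tf.Variable([0, 0, 0])
sliced_assigned = var_static.assign(sliced)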
Here is the explanation for your second question. StopAtStepHook relies on the global_step tensor being updated, which is not the case in your script. This code

tf_print = tf.group(tf.Print(input_batch, [input_batch]),
                    tf.assign_add(global_step, 1))

would work: essentially it groups the tf.Print operation and the global_step increment together, so global_step is incremented every time tf_print is run.
import tensorflow as tf

def train():
    """Stripped down and modified from cifar10.cifar10_train.train"""
    global_step = tf.contrib.framework.get_or_create_global_step()  # for StopAtStepHook
    images = tf.constant([[1, 2, 3], [1, 2, 3]])
    labels = tf.constant([[1, 2, 3], [1, 2, 3]])
    images, labels = tf.train.slice_input_producer([images, labels], shuffle=False)
    input_var = tf.Variable([0, 0, 0])
    images = input_var.assign(images)  # TODO placeholder would work ?
    input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
    tf_print = tf.group(tf.Print(input_batch, [input_batch]),
                        tf.assign_add(global_step, 1))
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(tf_print)

if __name__ == '__main__':
    train()
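With this change, each mon_sess.run(tf_print) both prints the batch and increments global_step, so StopAtStepHook(last_step=3) should make should_stop() return True after about three iterations and the loop exits instead of running forever.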