在Tensorflow中共享线程之间的变量

时间:2016-03-16 12:54:15

标签: python multithreading tensorflow

我正在尝试使用Python线程使用TensorFlow实现异步梯度下降。在主代码中,我定义了图形,包括一个训练操作,它获取一个变量以保持global_step的计数:

with tf.variable_scope("scope_global_step") as scope_global_step:
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)

如果我打印global_step的名称,我会:

  

scope_global_step / global_step:0

主代码还启动了几个线程来执行training方法:

threads = [threading.Thread(target=training, args=(sess, train_op, loss, scope_global_step)) for i in xrange(NUM_TRAINING_THREADS)]
for t in threads: t.start()

如果global_step的值大于或等于FLAGS.max_steps,我希望每个线程停止执行。为此,我按照以下方式构建training方法:

def training(sess, train_op, loss, scope_global_step):
    while (True):
         _, loss_value = sess.run([train_op, loss])
         with tf.variable_scope(scope_global_step, reuse=True): 
            global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
            global_step = global_step.eval(session=sess)
            if global_step >= FLAGS.max_steps: break

此操作失败并显示以下消息:

  

ValueError:欠共享:变量scope_global_step / global_step不存在,不允许。你的意思是在VarScope中设置reuse = None吗?

我可以看到:0在首次创建时添加到变量的名称中,当我尝试检索它时,不使用该后缀。为什么是这样? 如果我在尝试检索变量时手动将后缀添加到变量的名称,它仍然声称该变量不存在。为什么TensorFlow找不到变量?不应该在线程之间自动共享变量吗?我的意思是,所有线程都在同一个会话中运行,对吧?

另一个与我的training方法相关的问题:global_step.eval(session=sess)会再次执行图表,还是只会在gloabl_step执行后获取分配给train_op的值和loss操作?一般来说,从Python代码中使用的变量中获取值的推荐方法是什么?

1 个答案:

答案 0 :(得分:1)

TL; DR:将您在第一个代码片段中创建的global_step tf.Variable对象作为训练线程参数之一传递,并调用sess.run(global_step)在传入的变量上。

作为一般规则,您的训练循环(尤其是单独线程中的训练循环)不应修改图形。 tf.variable_scope()上下文管理器和tf.get_variable() 可以修改图表(即使它们并非总是如此),因此您不应在训练循环中使用它们。最安全的做法是在创建训练线程时将global_step对象(您在第一时间创建的对象)作为args元组之一传递。然后你可以简单地重写你的训练功能:

def training(sess, train_op, loss, global_step):
    while (True):
         _, loss_value = sess.run([train_op, loss])
         current_step = sess.run(global_step)
         if current_step >= FLAGS.max_steps: break

要回答您的其他问题,运行global_step.eval(session=sess)sess.run(global_step)只会获取global_step变量的当前值,而不会重新执行图表的其余部分。这是获取用于Python代码的tf.Variable值的推荐方法。