如何在不显式运行其输出的情况下使张量流分配成为计算图的一部分?

时间:2018-07-18 21:02:42

标签: tensorflow

我试图在张量流中创建自定义梯度,以实现本文建议的对数的指数平滑(无偏)梯度(https://arxiv.org/pdf/1801.04062.pdf)。我需要做的是创建一个存储指数平滑值的新变量,该变量将被更新并在自定义渐变函数中使用。此外,我需要一个标志来告诉我何时完成了第一个梯度计算,因此我可以将指数平滑值初始化为适当的值(与数据相关)。此外,自定义渐变函数的输出必须仅是渐变,因此从自定义渐变内部访问tf.assign的输出会很麻烦。最后,我不想创建第二个操作来“手动”通过在训练循环中单独运行来初始化指数平滑。无论如何,这都太复杂了,因此我在下面概述了一个抽象但简单的问题,该问题的解决方案可以解决我的问题:

我需要做的是以一个条件为条件来更新一个变量,而且我还需要更新第二个变量而不将其作为函数的显式输出提供。演示我的问题的示例代码如下:

import tensorflow as tf

a = tf.get_variable(name = "test",initializer=True)
b = tf.get_variable(name = "testval",initializer = 10.)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

def make_function(inp):
    with tf.variable_scope("",reuse = True):
        a = tf.get_variable(name = "test",dtype = tf.bool)
        b = tf.get_variable(name = "testval")

    iftrue = lambda: [tf.assign(b,inp),tf.assign(a,False)]
    iffalse = lambda: [tf.assign(b,(b + inp)/2),tf.assign(a,False)]

    acond,bcond = tf.cond(a,iftrue,iffalse)

    return acond

I = tf.placeholder(tf.float32)

tcond = make_function(I)

print("{}\tThe initial values of a and b".format(sess.run([a,b])))
print("{}\t\tRun, tcond1. output is the updated value of b.".format(sess.run(tcond,{I:1})))
print("{}\tNow we see that b has been updated, but a has not.".format(sess.run([a,b])))
print("{}\t\tSo now the value is 2 instead of 1.5 like it should be.".format(sess.run(tcond,{I:2})))

输出为:

[True, 10.0]    The initial values of a and b
1.0     Run, tcond1. output is the updated value of b.
[True, 1.0] Now we see that b has been updated, but a has not.
2.0     So now the value is 2 instead of 1.5 like it should be.

现在,我了解到我需要像sess.run(acond)这样的一行,其中acondmake_function中条件的输出,但是我不能返回它,因为我的函数需要只返回b的值(而不是a),我不想带走一个额外的操作,而我需要记住要在第一次训练迭代中运行,而不必在其他人。

那么,有没有一种方法可以将分配操作acond添加到计算图中,而无需显式返回并运行它sess.run

1 个答案:

答案 0 :(得分:0)

将此操作添加到自定义集合中,然后在最终操作(例如train_op)和您的acond之间创建依赖关系。

在方法内部:

tf.add_to_collection("to_run", acond)

在最终操作的定义中:

to_run = tf.get_collection("to_run")
with tf.control_dependencies(to_run):
    final_op = <something>

运行final_op时,可以确保您的acond已经执行。