How to use conditions and temporary variables in TensorFlow?

Time: 2018-03-25 19:42:18

Tags: tensorflow

I want to create a temporary variable in TF and subtract it from my input variable, but only during the training phase. Below is a simplified version of the code I use. Could you please give me some advice on how to make it work?

Please keep in mind that I don't want to create the variable if it is not the training phase.

import tensorflow as tf

def some_transformation(x):
    x0 = tf.get_variable('x0', initializer=tf.random_uniform([1], 
                         maxval=0.3, dtype=tf.float32), dtype=tf.float32)
    return tf.subtract(x, x0)

x = tf.placeholder("float", [])
is_traning = tf.placeholder(tf.int32, None)

x_transformed = tf.cond(is_traning > 0, lambda: some_transformation(x), lambda: x)
#x_transformed = some_transformation(x)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    out = sess.run(x_transformed, feed_dict={x: 10, is_traning: 1})
    print(out)

1 answer:

Answer 0 (score: 0)

Please include the error message when you post code. When I ran your code, I got this error:

ValueError: Initializer for variable x0/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

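This seems to happen because you call tf.get_variable inside some_transformation, which runs inside tf.cond, so the tf.random_uniform(...) initializer tensor is created inside the conditional. The error message tells you to change initializer=tf.random_uniform(...) to initializer=lambda: tf.random_uniform(...).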

Alternatively, if it fits your use case, you can define x0 outside the transformation and pass it in:

x_transformed = tf.cond(is_traning > 0, lambda: some_transformation(x, x0), lambda: x)
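A minimal runnable sketch of the second suggestion follows. Assumptions: it uses `tensorflow.compat.v1` with eager execution disabled so the TF1-style placeholder/Session code also runs under TF 2.x; x0 is made a scalar (shape `[]` instead of the question's `[1]`) so both tf.cond branches return the same shape; the feed values are only illustrative. Also note that tf.cond builds both branch functions when the graph is constructed, so in graph mode x0 ends up in the graph whatever value is later fed for is_traning; the flag only decides whether the subtraction runs.

```python
# Assumption: TF 2.x with the TF1 compatibility API. Under TF 1.x,
# `import tensorflow as tf` should behave the same way.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build a graph and run it in a Session, as in the question

def some_transformation(x, x0):
    # x0 is passed in, so no variable or initializer is created inside tf.cond.
    return tf.subtract(x, x0)

x = tf.placeholder(tf.float32, [])
is_traning = tf.placeholder(tf.int32, [])  # scalar flag: > 0 means training

# Created outside tf.cond. The scalar shape [] (an assumption; the question
# uses [1]) keeps both tf.cond branches at the same output shape.
x0 = tf.get_variable(
    'x0',
    initializer=tf.random_uniform([], maxval=0.3, dtype=tf.float32),
    dtype=tf.float32)

x_transformed = tf.cond(is_traning > 0,
                        lambda: some_transformation(x, x0),
                        lambda: x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x0_val = sess.run(x0)
    out_train = sess.run(x_transformed, feed_dict={x: 10.0, is_traning: 1})
    out_eval = sess.run(x_transformed, feed_dict={x: 10.0, is_traning: 0})
    print(x0_val, out_train, out_eval)
```

Feeding is_traning: 1 returns 10 minus the sampled x0, while is_traning: 0 returns the input 10.0 unchanged.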