如何在TensorFlow中指定值while_loop

时间:2017-08-01 12:22:50

标签: python filter tensorflow while-loop time-series

目标是在TensorFlow中实现一个循环功能,以便随时间过滤信号。

input后来呈现为[batch, in_depth, in_height, in_width, in_channels]形式的5-D张量。我想使用tf.while_loop迭代in_depth并重新分配取决于之前时间步长值的值。但是,我无法在循环中重新分配变量值。

为了简化问题,我创建了问题的一维版本:

def condition(i, signal):
    return tf.less(i, signal.shape[0])

def operation(i, signal):
    signal = tf.get_variable("signal")
    signal = signal[i].assign(signal[i-1]*2)
    i = tf.add(i, 1)
    return (i, signal)

with tf.variable_scope("scope"):
    i = tf.constant(1)
    init = tf.constant_initializer(0)
    signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
    signal = tf.assign(signal[0], 1.2)

with tf.variable_scope("scope", reuse = True):
    loops_vars = [i, signal]
    i, signal = tf.while_loop(condition, operation, loop_vars, back_prop = False)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    i, signal = session.run([i, signal])

tf.assign返回一个操作,该操作必须在会话中运行才能进行评估(see here for further details)。

我预计,TensorFlow会在循环中链接操作,因此在运行会话并请求signal后执行分配。但是,当我执行给定代码并打印结果时,signal contatins [1.2, 0, 0, 0]i包含(按预期方式)4

我在这里有什么误解,如何更改代码,以便重新分配signal的值?

1 个答案:

答案 0 :(得分:1)

虽然循环变量仅通过body函数的返回值更新,但您不应使用自己的赋值操作。相反,您需要在循环后返回您希望signal拥有的值,与i一样。

此外,您不应该在身体或条件中使用tf.get_variable,只需使用您收到的参数。

# ...
def operation(i, signal):
    shape = signal.shape
    signal = tf.concat([signal[:i], [signal[i - 1] * 2], signal[i + 1:]], axis=0)
    signal.set_shape(shape) # Shapes have to be invariant for the loop
    i = tf.add(i, 1)
    return (i, signal)

with tf.variable_scope("scope"):
    i = tf.constant(1)
    init = tf.constant_initializer(1.2) # init signal here and avoid tf.assign
    signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
    # signal = tf.assign(signal[0], 1.2) 

# ...