How do I correctly update variables inside a while loop in TensorFlow?

Date: 2017-11-19 18:39:40

Tags: tensorflow while-loop

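Can someone explain (or point me to the part of the documentation I missed) how to correctly update a tf.Variable() inside a tf.while_loop? I am trying to use the assign() method to update variables inside the loop so that they hold some information until the next iteration. However, this does nothing.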

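Since the minimizer updates the values of mu_tf and sigma_tf but step_mu stays unchanged, I am clearly doing something wrong, but I don't understand what. To be specific, I know that assign() does not do anything until it is executed when the graph is run, so I know that I can do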

sess.run(step_mu.assign(mu_tf))

and that will update step_mu, but I want to do this properly inside the loop. I don't understand how to add the assign operation to the loop body.
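A minimal sketch of the usual graph-mode pattern, assuming a TensorFlow release that ships `tensorflow.compat.v1` (late 1.x or 2.x) so it runs in TF1-style graph mode: an update op created in the loop body is executed only if something the body returns depends on it, so the update goes into a `tf.control_dependencies` block wrapped around the returned loop variable. The `counter` variable is hypothetical, and `use_resource=True` is an assumed choice, not something taken from the question.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, as in the question

# Hypothetical variable; use_resource=True is an assumed choice
counter = tf.Variable(0, dtype=tf.int32, use_resource=True)

def cond(step):
    return tf.less(step, 10)

def body(step):
    # Creating the update op alone is not enough: it only runs because
    # the value returned below has a control dependency on it.
    update = counter.assign_add(1)
    with tf.control_dependencies([update]):
        return tf.add(step, 1)

loop = tf.while_loop(cond, body, [tf.constant(0)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(loop)
    result = sess.run(counter)

print(result)  # 10
```

The body runs ten times and each iteration's returned `step` waits for that iteration's `assign_add`, so the counter ends at 10.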

A simplified working example of what I am doing:

import numpy as np
import tensorflow as tf

mu_true = 0.5
sigma_true = 1.5

n_events = 100000

# Placeholders
X = tf.placeholder(dtype=tf.float32)

# Variables
mu_tf = tf.Variable(initial_value=tf.random_normal(shape=[], mean=0., stddev=0.1,
                                                dtype=tf.float32),
                    dtype=tf.float32)
sigma_tf = tf.Variable(initial_value=tf.abs(tf.random_normal(shape=[], mean=1., stddev=0.1,
                                                dtype=tf.float32)),
                       dtype=tf.float32,
                       constraint=lambda x: tf.abs(x))

step_mu = tf.Variable(initial_value=-99999., dtype=tf.float32)   
step_loss = tf.Variable(initial_value=-99999., dtype=tf.float32)

# loss function
gaussian_dist = tf.distributions.Normal(loc=mu_tf, scale=sigma_tf)
log_prob = gaussian_dist.log_prob(value=X)
negative_log_likelihood = -1.0 * tf.reduce_sum(log_prob)

# optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)

# sample data
x_sample = np.random.normal(loc=mu_true, scale=sigma_true, size=n_events)

# Construct the while loop.
def cond(step):
    return tf.less(step, 10)

def body(step):
    # gradient step
    train_op = optimizer.minimize(loss=negative_log_likelihood)

    # update step parameters
    with tf.control_dependencies([train_op]):
        step_mu.assign(mu_tf)

        return tf.add(step,1)

loop = tf.while_loop(cond, body, [tf.constant(0)])

# Execute the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    step_loss = sess.run(fetches=negative_log_likelihood, feed_dict={X: x_sample})

    print('Before loop:\n')
    print('mu_tf: {}'.format(sess.run(mu_tf)))
    print('sigma_tf: {}'.format(sess.run(sigma_tf)))
    print('step_mu: {}'.format(sess.run(step_mu)))
    print('step_loss: {}\n'.format(step_loss))

    sess.run(fetches=loop, feed_dict={X: x_sample})

    print('After loop:\n')
    print('mu_tf: {}'.format(sess.run(mu_tf)))
    print('sigma_tf: {}'.format(sess.run(sigma_tf)))
    print('step_mu: {}'.format(sess.run(step_mu)))
    print('step_loss: {}'.format(step_loss))
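Applying the same idea to the `step_mu` pattern above gives roughly the following sketch, again assuming `tensorflow.compat.v1`. To keep it self-contained and deterministic, `param` and `step_param` are hypothetical stand-ins for `mu_tf` and `step_mu`, and `param.assign_add(0.5)` is a hypothetical stand-in for `train_op`; the optimizer itself is deliberately left out. The first control-dependency block makes the copy read `param` after the update; the second makes the returned counter wait for the copy, so the copy actually executes.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical stand-ins: param ~ mu_tf, step_param ~ step_mu
param = tf.Variable(0.0, use_resource=True)
step_param = tf.Variable(-99999.0, use_resource=True)

def cond(step):
    return tf.less(step, 10)

def body(step):
    update = param.assign_add(0.5)  # stand-in for train_op
    with tf.control_dependencies([update]):
        # read_value() created here reads param after the update
        save = step_param.assign(param.read_value())
    # Tie the loop output to the copy so the copy is not skipped
    with tf.control_dependencies([save]):
        return tf.add(step, 1)

loop = tf.while_loop(cond, body, [tf.constant(0)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(loop)
    final_param, final_step_param = sess.run([param, step_param])

print(final_param, final_step_param)  # 5.0 5.0
```

In the original body, the return statement sits inside the `control_dependencies([train_op])` block, so the gradient step runs, but nothing depends on the op created by `step_mu.assign(mu_tf)`; that appears to be why `step_mu` is never written.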
