Can someone explain (or point me to the relevant place in the documentation that I've missed) how to properly update a tf.Variable() inside a tf.while_loop? I am trying to update variables in the loop, using the assign() method, that will store some information until the next iteration of the loop. However, this isn't doing anything. Since the minimizer is updating the values of mu_tf and sigma_tf while step_mu is not changing, I am clearly doing something wrong, but I don't understand what it is. Specifically, I should say that I know assign() does not do anything until it is executed when the graph is run, so I know I can do

sess.run(step_mu.assign(mu_tf))

and that will update step_mu, but I want to do this properly inside the loop. I don't understand how to add an assign operation to the loop body.
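To be explicit about what I mean (using the names from my example below), the two steps are separate:

assign_op = step_mu.assign(mu_tf)  # this only adds an op to the graph; nothing changes yet
sess.run(assign_op)                # step_mu is actually updated only when the op is run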
A simplified working example of what I'm doing is below:
import numpy as np
import tensorflow as tf

mu_true = 0.5
sigma_true = 1.5
n_events = 100000

# Placeholders
X = tf.placeholder(dtype=tf.float32)

# Variables
mu_tf = tf.Variable(initial_value=tf.random_normal(shape=[], mean=0., stddev=0.1,
                                                   dtype=tf.float32),
                    dtype=tf.float32)
sigma_tf = tf.Variable(initial_value=tf.abs(tf.random_normal(shape=[], mean=1., stddev=0.1,
                                                             dtype=tf.float32)),
                       dtype=tf.float32,
                       constraint=lambda x: tf.abs(x))
step_mu = tf.Variable(initial_value=-99999., dtype=tf.float32)
step_loss = tf.Variable(initial_value=-99999., dtype=tf.float32)

# loss function
gaussian_dist = tf.distributions.Normal(loc=mu_tf, scale=sigma_tf)
log_prob = gaussian_dist.log_prob(value=X)
negative_log_likelihood = -1.0 * tf.reduce_sum(log_prob)

# optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)

# sample data
x_sample = np.random.normal(loc=mu_true, scale=sigma_true, size=n_events)

# Construct the while loop.
def cond(step):
    return tf.less(step, 10)

def body(step):
    # gradient step
    train_op = optimizer.minimize(loss=negative_log_likelihood)
    # update step parameters
    with tf.control_dependencies([train_op]):
        step_mu.assign(mu_tf)
    return tf.add(step, 1)

loop = tf.while_loop(cond, body, [tf.constant(0)])

# Execute the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step_loss = sess.run(fetches=negative_log_likelihood, feed_dict={X: x_sample})
    print('Before loop:\n')
    print('mu_tf: {}'.format(sess.run(mu_tf)))
    print('sigma_tf: {}'.format(sess.run(sigma_tf)))
    print('step_mu: {}'.format(sess.run(step_mu)))
    print('step_loss: {}\n'.format(step_loss))
    sess.run(fetches=loop, feed_dict={X: x_sample})
    print('After loop:\n')
    print('mu_tf: {}'.format(sess.run(mu_tf)))
    print('sigma_tf: {}'.format(sess.run(sigma_tf)))
    print('step_mu: {}'.format(sess.run(step_mu)))
    print('step_loss: {}'.format(step_loss))
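For what it's worth, my current guess is that the assign op I create in body is never executed because nothing the loop returns depends on it, so it gets pruned from the graph. A minimal sketch of what I imagine the fix might look like (assuming chaining control dependencies into the returned counter is the right mechanism; I'm not sure it is):

def body(step):
    # gradient step
    train_op = optimizer.minimize(loss=negative_log_likelihood)
    # keep a handle to the assign op instead of discarding it ...
    with tf.control_dependencies([train_op]):
        assign_op = step_mu.assign(mu_tf)
    # ... and make the returned counter depend on it, so the loop
    # cannot advance an iteration without running the assignment
    with tf.control_dependencies([assign_op]):
        return tf.add(step, 1)

But I don't know whether this is the idiomatic way to thread stateful updates through tf.while_loop, or whether the loop variables themselves are supposed to carry this state instead.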