考虑以下代码:
import tensorflow as tf
global_step = tf.train.create_global_step()
x = tf.Variable(100.0)
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(x, global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step, _ = sess.run([global_step, train_op])
print(step)
我得到的输出是'1',但是我认为没有什么能阻止Tensorflow给我'0',即,在train_op
中增加它的'assign'op之前的全局步骤变量的值。实际上,我还有另一个更复杂的Tensorflow程序,表现出此行为,其中我从Session.run([global_step, train_op])
获得的全局步长值在运行它的两台机器之间是一一对应的。
对于全局step变量,如何确定从train_op
之前获取其值,或者如何从train_op
之后获取其值?
我知道我可以在sess.run([global_step])
之前或之后分别进行sess.run([train_op])
,但是如果不涉及过度复杂的代码,我想在单个session.run()
中做尽可能多的事情。我知道我可以通过将global_step
分配给另一个变量并在tf.assign
和train_op
之间建立控件依赖性来获得预增值:
import tensorflow as tf
global_step = tf.train.create_global_step()
global_step2 = tf.get_variable('step-mirror', dtype=global_step.dtype,
shape=global_step.shape)
global_step2 = tf.assign(global_step2, global_step)
x = tf.Variable(100.0)
optimizer = tf.train.AdamOptimizer()
with tf.control_dependencies([global_step2]):
train_op = optimizer.minimize(x, global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step, _ = sess.run([global_step2, train_op])
print(step)
但是我正在寻找一种更简单的方法,也许是我缺少Tensorflow功能,因为它在指定任何变量之前都没有指定评估变量。
编辑:响应于此comment,此操作不起作用,并且打印为“ 1”而不是“ 0”:
import tensorflow as tf
global_step = tf.train.create_global_step()
x = tf.Variable(100.0)
optimizer = tf.train.AdamOptimizer()
with tf.control_dependencies([global_step]):
train_op = optimizer.minimize(x, global_step=global_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step, _ = sess.run([global_step, train_op])
print(step)
答案 0 :(得分:1)
您可以使用它来读取train_op之后的全局步骤:
import tensorflow as tf
global_step = tf.train.create_global_step()
x = tf.Variable(100.0)
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(x, global_step=global_step)
with tf.control_dependencies([train_op]):
global_step_value = global_step.read_value()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step, _ = sess.run([global_step_value, train_op])
print(step)
这里global_step_value
不再是变量。计算global_step
之后,它是一个train_op
值的张量。 here在“使用变量”下进行了说明。