我正在使用tf.Print()
调试模型,但是tf.Print()
打印得太频繁了。有什么好的策略可以限制此函数的输出,也许每第n个纪元就调用一次?
答案 0 :(得分:2)
您可以在问题中提到的每个第n个纪元打印输出。这是基本示例:
if epoch % n == 0:
...
答案 1 :(得分:2)
您可以在TF变量中保留一个用于步数/历元的计数器,然后在计数器为n
的倍数时将tf.cond与tf.math.mod结合使用来打印张量。
这里是一个例子:
import tensorflow as tf
def print_step(step):
print_op = tf.print(step)
with tf.control_dependencies([print_op]):
out = tf.identity(step)
return out
step = tf.Variable(0)
print_freq = 100
print_flag = tf.equal(tf.math.mod(step, print_freq), 0)
update_step = tf.assign(step, step + 1)
cond_print_op = tf.cond(print_flag,
lambda: print_step(update_step),
lambda: update_step)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(1000):
sess.run(cond_print_op)