tf.Print()输出减少

时间:2019-08-27 19:58:15

标签: python tensorflow keras

我正在使用tf.Print()调试模型,但是tf.Print()打印得太频繁了。有什么好的策略可以限制此函数的输出,也许每第n个纪元就调用一次?

2 个答案:

答案 0 :(得分:2)

您可以在问题中提到的每个第n个纪元打印输出。这是基本示例:

if epoch % n == 0:
    ...

答案 1 :(得分:2)

您可以在TF变量中保留一个用于步数/历元的计数器,然后在计数器为n的倍数时将tf.condtf.math.mod结合使用来打印张量。

这里是一个例子:

import tensorflow as tf


def print_step(step):
    print_op = tf.print(step)
    with tf.control_dependencies([print_op]):
        out = tf.identity(step)
    return out


step = tf.Variable(0)
print_freq = 100

print_flag = tf.equal(tf.math.mod(step, print_freq), 0)
update_step = tf.assign(step, step + 1)
cond_print_op = tf.cond(print_flag,
                        lambda: print_step(update_step),
                        lambda: update_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        sess.run(cond_print_op)