如果我的任何模型变量是NaN,我如何检查TensorFlow?

时间:2018-01-08 17:25:39

标签: python loops tensorflow nan

我的模特是:

weights = {
    'h1': tf.Variable(tf.random_normal([num_input, num_hidden_1]),name="h1"),
    'h2': tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2]),name="h2"),
    'h3': tf.Variable(tf.random_normal([num_hidden_2, num_hidden_3]),name="h3"),
    'wout': tf.Variable(tf.random_normal([num_hidden_3, num_output]),name="wout")
}

biases = {
    'b1': tf.Variable(tf.random_normal([num_hidden_1]),name="b1"),
    'b2': tf.Variable(tf.random_normal([num_hidden_2]),name="b2"),
    'b3': tf.Variable(tf.random_normal([num_hidden_3]),name="b3"),
    'bout': tf.Variable(tf.random_normal([num_output]),name="bout")
}

我知道如何使用tf.is_nan检查单个元素是否为NaN:

def save_weights():
    if sess.run(tf.is_nan(sess.run(weights["h1"]).tolist()[0][0])):
        utils.printflush("weights have nan, refused to save")

当然,我可以循环遍历Python中的所有元素,但是对于一百万个权重和偏差,这是非常耗时的。是否有可以执行此操作的TensorFlow操作(迭代所有模型变量)?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.reduce_any检查张量中是否至少有一个元素不为零。

另一方面,您可能需要查看tf.verify_tensor_all_finite,如果在张量上找到NaN或无限值,将会中断执行(请阅读tf.Assert的文档以了解如何使用图表上的断言)。

在任何情况下,请记住所有这些功能(包括tf.is_nan)都会在图表中创建新操作。最好先创建所有操作,然后根据需要调用run,以确保每次运行时图形不会不必要地增长。