tensorflow case error:无效参数:断言失败:没有条件评估为True

时间:2018-01-05 16:25:34

标签: python tensorflow

评估此代码:

import tensorflow as tf

def example(trans_type):
    trans_type = tf.Print(trans_type, [trans_type, tf.equal(trans_type, 2)], "trans_type")

    au = tf.case({tf.equal(trans_type, 0): lambda: tf.constant(0),
                  tf.equal(trans_type, 1): lambda: tf.constant(1),
                  tf.equal(trans_type, 2): lambda: tf.constant(2),
                  tf.equal(trans_type, 3): lambda: tf.constant(3),
                  tf.equal(trans_type, 4): lambda: tf.constant(4)
                  }, exclusive=True)

    return au


if __name__ == '__main__':

    tr_data = tf.data.Dataset.from_tensor_slices([2])
    tr_data = tr_data.map(example)
    tr_data = tr_data.repeat().batch(1)

    it_x = tr_data.make_one_shot_iterator()

    with tf.Session() as sess:
        sess.run(it_x.make_initializer(tr_data))
        r = sess.run(it_x.get_next())   # this one works however: r = sess.run(example(2))
        print(str(r))

导致以下错误:

2018-01-05 19:06:49.022920: I tensorflow/core/kernels/logging_ops.cc:79] trans_type[2][1]
2018-01-05 19:06:49.023131: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: assertion failed: [None of the conditions evaluated as True. Conditions: (Equal_1:0, Equal_2:0, Equal_3:0, Equal_4:0, Equal_5:0), Values:] [0 0 1 0 0]
 [[Node: case/If_0/Assert_1/AssertGuard/Assert = Assert[T=[DT_STRING, DT_BOOL], summarize=5](case/If_0/Assert_1/AssertGuard/Assert/Switch, case/If_0/Assert_1/AssertGuard/Assert/data_0, case/If_0/Assert_1/AssertGuard/Assert/Switch_1)]]

虽然tf.Print语句清楚地打印trans_type = 2,甚至表明tf.equal的计算结果为true。这里有什么问题,更重要的是 - 当tf.Print语句给出令人困惑的结果时,如何调试它?

编辑:改为最小的例子。它似乎与数据集/迭代器有关,只是调用该方法可以正常工作。

0 个答案:

没有答案