Evaluating this code:
import tensorflow as tf

def example(trans_type):
    trans_type = tf.Print(trans_type, [trans_type, tf.equal(trans_type, 2)], "trans_type")
    au = tf.case({tf.equal(trans_type, 0): lambda: tf.constant(0),
                  tf.equal(trans_type, 1): lambda: tf.constant(1),
                  tf.equal(trans_type, 2): lambda: tf.constant(2),
                  tf.equal(trans_type, 3): lambda: tf.constant(3),
                  tf.equal(trans_type, 4): lambda: tf.constant(4)
                  }, exclusive=True)
    return au

if __name__ == '__main__':
    tr_data = tf.data.Dataset.from_tensor_slices([2])
    tr_data = tr_data.map(example)
    tr_data = tr_data.repeat().batch(1)
    it_x = tr_data.make_one_shot_iterator()
    with tf.Session() as sess:
        sess.run(it_x.make_initializer(tr_data))
        r = sess.run(it_x.get_next())  # this one works however: r = sess.run(example(2))
        print(str(r))
results in the following error:
2018-01-05 19:06:49.022920: I tensorflow/core/kernels/logging_ops.cc:79] trans_type[2][1]
2018-01-05 19:06:49.023131: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: assertion failed: [None of the conditions evaluated as True. Conditions: (Equal_1:0, Equal_2:0, Equal_3:0, Equal_4:0, Equal_5:0), Values:] [0 0 1 0 0]
[[Node: case/If_0/Assert_1/AssertGuard/Assert = Assert[T=[DT_STRING, DT_BOOL], summarize=5](case/If_0/Assert_1/AssertGuard/Assert/Switch, case/If_0/Assert_1/AssertGuard/Assert/data_0, case/If_0/Assert_1/AssertGuard/Assert/Switch_1)]]
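This happens even though the tf.Print statement clearly prints trans_type = 2 and even shows that tf.equal evaluates to true. What is the problem here, and, more importantly, how can I debug this when tf.Print gives such confusing results?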
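To make the contradiction explicit, here is a plain-Python model of the selection rule. It assumes two rules, taken from the tf.case documentation and from the error message above: with exclusive=True at most one predicate may be true, and when no default is given, a run in which no predicate is true fails with "None of the conditions evaluated as True". `select_branch` is a hypothetical helper for illustration only, not TensorFlow's implementation. Applied to the Values vector from the assertion message, it selects branch 2 and raises nothing. That is why the error looks inconsistent with the printed predicates.

```python
def select_branch(pred_values, exclusive=True):
    """Plain-Python model (NOT TensorFlow code) of tf.case branch selection.

    pred_values: predicate results (0/1 or bool), in branch order.
    Returns the index of the selected branch, or raises ValueError
    for the two failure cases described in the lead-in.
    """
    true_idx = [i for i, p in enumerate(pred_values) if p]
    # exclusive=True: more than one true predicate is an error.
    if exclusive and len(true_idx) > 1:
        raise ValueError("more than one condition is True but exclusive=True")
    # No default branch: at least one predicate must be true.
    if not true_idx:
        raise ValueError("None of the conditions evaluated as True")
    # Otherwise take the first true predicate in order.
    return true_idx[0]


if __name__ == "__main__":
    log_values = [0, 0, 1, 0, 0]  # "Values:" from the assertion message
    print(select_branch(log_values))  # prints 2
```

Because the model accepts exactly this vector, the failed assertion suggests that the branch-selection logic itself misbehaves inside the Dataset.map function, not that trans_type or tf.equal have wrong values.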
Edit: reduced to a minimal example. The problem seems to be related to the dataset/iterator; calling the method directly works fine.