我应该如何判断张量流Tensor中的所有数字是0还是1?
bad_mask = tf.Variable([[0.0,1.0,0.2,0.0,0.0], [0.0,5.0,0.0,2.3,0.0]])
good_mask = tf.Variable([[0.0,1.0,1.0,0.0,0.0], [0.0,1.0,0.0,1.0,0.0]])
我想使用tf.assert
。
答案 0 :(得分:2)
喜欢这段代码(已测试):
import tensorflow as tf
bad_mask = tf.Variable([[0.0,1.0,0.2,0.0,0.0], [0.0,5.0,0.0,2.3,0.0]])
good_mask = tf.Variable([[0.0,1.0,1.0,0.0,0.0], [0.0,1.0,0.0,1.0,0.0]])
x = tf.Assert( tf.reduce_all(
tf.logical_or( tf.equal( good_mask, 0.0 ), tf.equal( good_mask, 1.0 ) )
), [ good_mask ] )
y = tf.Assert( tf.reduce_all(
tf.logical_or( tf.equal( bad_mask, 0.0 ), tf.equal( bad_mask, 1.0 ) )
), [ bad_mask ] )
with tf.Session() as sess:
sess.run( tf.global_variables_initializer() )
print( sess.run ( x ) )
print( sess.run ( y ) )
将输出:
无
和
InvalidArgumentError:断言失败:[[0 1 0.2] ...]
[[Node:Assert_4 / AssertGuard / Assert = Assert [T = [DT_FLOAT],summarize = 3,_device =“/ job:localhost / replica:0 / task:0 / device:CPU:0”](Assert_4 / AssertGuard /断言/切换,Assert_4 / AssertGuard / Assert / Switch_1)]]
根据需要。
答案 1 :(得分:0)
使用tf.unstack
将张量转换为列表,然后检查所有值是0还是1.