我应该如何检查张量流Tensor中的所有数字是否都是二进制数

时间:2018-05-10 19:45:46

标签: python tensorflow

我应该如何判断张量流Tensor中的所有数字是0还是1?

bad_mask = tf.Variable([[0.0,1.0,0.2,0.0,0.0], [0.0,5.0,0.0,2.3,0.0]])
good_mask = tf.Variable([[0.0,1.0,1.0,0.0,0.0], [0.0,1.0,0.0,1.0,0.0]])

我想使用tf.assert

2 个答案:

答案 0 :(得分:2)

喜欢这段代码(已测试):

import tensorflow as tf

bad_mask = tf.Variable([[0.0,1.0,0.2,0.0,0.0], [0.0,5.0,0.0,2.3,0.0]])
good_mask = tf.Variable([[0.0,1.0,1.0,0.0,0.0], [0.0,1.0,0.0,1.0,0.0]])

x = tf.Assert( tf.reduce_all( 
    tf.logical_or( tf.equal( good_mask, 0.0 ), tf.equal( good_mask, 1.0 ) )
), [ good_mask ] )

y = tf.Assert( tf.reduce_all( 
    tf.logical_or( tf.equal( bad_mask, 0.0 ), tf.equal( bad_mask, 1.0 ) )
), [ bad_mask ] )

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer() )
    print( sess.run ( x ) )
    print( sess.run ( y ) )

将输出:

  

  

InvalidArgumentError:断言失败:[[0 1 0.2] ...]
  [[Node:Assert_4 / AssertGuard / Assert = Assert [T = [DT_FLOAT],summarize = 3,_device =“/ job:localhost / replica:0 / task:0 / device:CPU:0”](Assert_4 / AssertGuard /断言/切换,Assert_4 / AssertGuard / Assert / Switch_1)]]

根据需要。

答案 1 :(得分:0)

使用tf.unstack将张量转换为列表,然后检查所有值是0还是1.