如何删除数组中的行包含小于11的值

时间:2018-08-18 15:03:28

标签: python-3.x tensorflow

以下示例是普通的Python代码。但是如何在Tensorflow库中做到这一点?我要删除数组中包含小于11的值的行。我希望此编码仅出于准确预测而计算准确性。

a = np.array([[ 0,  1,  2,  0,  4,  5,  6,  7,  8,  10],
              [ 0, 11,  0, 13,  0, 15,  0, 17, 18,  0]])
print (a[a.max(axis=1) >= 11])

1 个答案:

答案 0 :(得分:0)

使用tf.reduce_max来计算沿轴的最大值,并使用tf.boolean_mask根据布尔条件将张量子集化:

import tensorflow as tf
tf.InteractiveSession()

t = tf.constant(a)

t1 = tf.boolean_mask(t, tf.reduce_max(t, axis=1) >= 11)

t1.eval()
# array([[ 0, 11,  0, 13,  0, 15,  0, 17, 18,  0]])