在tenserflow中显示元素的位置

时间:2017-09-22 14:30:37

标签: python tensorflow

此代码仅显示数组的索引,其中使用了

tensor1 = tf.convert_to_tensor(np.array([1536, 2, 5], dtype='float32'))
tf.where(tensor1 > 3).eval().reshape(1, 2)[0]

输出是:

  

array([0,2],dtype = int64)

我使用索引进行循环打印:

for i in tf.where(tensor1 > 3).eval().reshape(1, 2)[0]:
    print(tensor1[i].eval())

有没有办法没有for循环呢?

1 个答案:

答案 0 :(得分:0)

tf.gather也可用于索引数组,所以

indices = tf.where(tensor1 > 3)
tf.gather(tensor1, indices) 

应该做正确的事