在tensorflow中如何返回张量中第一个匹配值的索引

时间:2017-07-20 12:05:45

标签: python tensorflow

我必须用tensorflow的语法编写代码。这是我期望的功能:

# read the first matched datum's index in a vector
index = next(idx for idx, value in enumerate(ladder) if value >= 0.99 ) + 1

在张量流中,我试过

a = tf.greater_equal(ladder, 0.99)
b = tf.where(a)         # It'll returns all the indices of the matched data

但我不知道如何继续,因为我无法访问张量的值。

所以我的问题是如何在sess.run之前获得第一个匹配数据的索引?

2 个答案:

答案 0 :(得分:0)

试试下面的示例。在张量流中,您首先构建一个节点图,显示应该执行的计算,然后在会话的上下文中执行此图(它保存图的状态)以实际执行计算。

import tensorflow as tf

data = [ 1, 2, 3, 4, 5, 4, 3, 2, 1]
data_input = tf.constant(data)

indexes = tf.where(tf.greater(data, 3))

with tf.Session() as sess:
    print(sess.run(indexes))

# Return [[3], [4], [5]]

以下是tf.Session的文档,但图表和会话的概念也在tutorial

的第一部分进行了解释

答案 1 :(得分:0)

如您所述,在下面的代码中

b

您无法访问b的值。

您可以访问b的值。在此之前,您应该知道b是二维矩阵。

由于ladder[b[0, 0]]再次是张量,因此您可以在张量流函数中使用它。您无法在没有会话的情况下直接访问它,就像您在会话之前无法像val = tf.gather(ladder, b[0, 0])那样编码。

但是如果你想执行相同的操作,那么你必须像下面那样编码

{{1}}

同样,您可以使用b。

的值