tensorflow,如何索引张量的真值?

时间:2017-11-16 06:45:17

标签: python tensorflow

我有一个像张力一样的张量:

SELECT 
    p.id AS id,
    -- sum(a), max(b), min(c) <--- what you want from JobsWithoutPrice table
FROM st_person p inner join (
    (SELECT jobs_without_price.id, -- jobs_without_price.a, jobs_without_price.b, jobs_without_price.c
    FROM st_job jobs_without_price
    WHERE jobs_without_price.price IS NULL
    FOR XML AUTO, ROOT('jobs')) AS JobsWithoutPrice
) on (jobs_without_price.person_id = p.id)
GROUP BY p.id

如何查找所有[false, false, true, false, true, false] 值的索引?

我的解决方案是:

  1. 将其设为1和0值
  2. 使用argmax API查找一个索引,然后将其设置为true / false
  3. 再次使用argmax查找下一个0 / true
  4. 但是这个解决方案并不是那么好。

3 个答案:

答案 0 :(得分:2)

import tensorflow as tf

a = tf.constant([False,False,True,False,True],dtype=tf.bool)
b = tf.where(a)
sess = tf.Session()
print(sess.run(b))

答案 1 :(得分:0)

这是你正在寻找的吗? [k for k, value in enumerate(tensor) if value]

答案 2 :(得分:0)

In [1]: import tensorflow as tf                                                                                                                                                                                      

In [2]: a = tf.constant([False, False, True, True])

In [3]: a_n = [tf.cond(tf.equal(v, tf.constant(True)), lambda: tf.constant(k), lambda: tf.constant(-1)) for k, v in enumerate(tf.unstack(a))]                                                                        

In [4]: sess = tf.Session()

In [5]: sess.run(a_n)                                                                                                                                                                                                
Out[5]: [-1, -1, 2, 3]

希望这会有所帮助......