Tensorflow获取张量值的索引

时间:2017-12-11 19:13:14

标签: python matrix tensorflow

给定矩阵和向量,我想在矩阵的相应行中找到值的索引。

m = tf.constant([[0, 2, 1],[2, 0, 1]])  # matrix
y = tf.constant([1,2])  # values whose indices should be found

理想输出为[2,0],因为y的第一个值1位于m的第一个向量的索引2处。 y,2的第二个值位于m的第二个向量的索引0处。

1 个答案:

答案 0 :(得分:1)

我找到一个解决方案。但我不知道是否有更好的。

m = tf.constant([[0, 2, 1],[2, 0, 1]])  # matrix
y = tf.constant([1,2])  # values whose indices should be found
y = tf.reshape(y, (y.shape[0], 1))  # [[1], [2]]
cols = tf.where(tf.equal(m, y))[:,-1]  # [2,0]

init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(cols))

以上输出:[2, 0]