如何在TensorFlow中查找第一个匹配元素的索引

时间:2017-02-12 05:54:09

标签: tensorflow

我正在寻找一种TensorFlow方法来实现类似于Python的list.index()函数。

给定一个矩阵和一个要查找的值,我想知道矩阵每行中第一次出现的值。

例如,

m is a <batch_size, 100> matrix of integers
val = 23

result = [0] * batch_size
for i, row_elems in enumerate(m):
  result[i] = row_elems.index(val)

我不能认为&#39; val&#39;每行只出现一次,否则我会用tf.argmax(m == val)实现它。在我的情况下,重要的是获得&#39; val&#39;的第一个出现的索引。而不是任何。

4 个答案:

答案 0 :(得分:10)

似乎tf.argmax的作用类似于np.argmax(根据the test),当最多值出现多次时,它会返回第一个索引。 您可以使用tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)来获得所需内容。但是,目前tf.argmax的行为在多次出现最大值时未定义。

如果您担心未定义的行为,可以对tf.argmin的返回值应用tf.where,如@Igor Tsvetkov建议的那样。 例如,

# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0  ,   0, val,   0, val],
          [val,   0, val, val,   0],
          [0  , val,   0,   0,   0]]

tmp_indices = tf.where(tf.equal(m, val))
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]

请注意,如果某行中不包含tf.segment_min,则InvalidArgumentError会引发val。在row_elems.index(val)不包含row_elems时,您的代码中val也会引发异常。

答案 1 :(得分:3)

看起来有点难看但有效(假设mval都是张量):

idx = list()
for t in tf.unpack(m, axis=0):
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
idx = tf.pack(idx, axis=0)

修改 如上所述Yaroslav Bulatov,您可以使用tf.map_fn获得相同的结果:

def index1d(t):
    return tf.reduce_min(tf.where(tf.equal(t, val)))

idx = tf.map_fn(index1d, m, dtype=tf.int64)

答案 2 :(得分:1)

这里是解决该问题的另一种方法,假设每一行都有命中率。

import tensorflow as tf

val = 3
m = tf.constant([
    [0  ,   0,   val,   0, val],
    [val,   0,   val, val,   0],
    [0  , val,     0,   0,   0]])

# replace all entries in the matrix either with its column index, or out-of-index-number
match_indices = tf.where(                          # [[5, 5, 2, 5, 4],
    tf.equal(val, m),                              #  [0, 5, 2, 3, 5],
    x=tf.range(tf.shape(m)[1]) * tf.ones_like(m),  #  [5, 1, 5, 5, 5]]
    y=(tf.shape(m)[1])*tf.ones_like(m))

result = tf.reduce_min(match_indices, axis=1)

with tf.Session() as sess:
    print(sess.run(result)) # [2, 0, 1]

答案 3 :(得分:1)

这里是一个解决方案,它也考虑了矩阵不包含元素的情况(来自DeepMind的github存储库的解决方案)

def get_first_occurrence_indices(sequence, eos_idx):
    '''
    args:
        sequence: [batch, length]
        eos_idx: scalar
    '''
    batch_size, maxlen = sequence.get_shape().as_list()
    eos_idx = tf.convert_to_tensor(eos_idx)
    tensor = tf.concat(
            [sequence, tf.tile(eos_idx[None, None], [batch_size, 1])], axis = -1)
    index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
    index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
    index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1], 
index_all_occurrences[:, 0])
    index_first_occurrences.set_shape([batch_size])
    index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)
    
    return index_first_occurrences

并且:

import tensorflow as tf
mat = tf.Variable([[1,2,3,4,5], [2,3,4,5,6], [3,4,5,6,7], [0,0,0,0,0]], dtype = tf.int32)
idx = 3
first_occurrences = get_first_occurrence_indices(mat, idx)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(first_occurrence) # [3, 2, 1, 5]