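I am looking for a TensorFlow way to implement something similar to Python's list.index() function.
Given a matrix and a value to find, I want to know the index of the first occurrence of that value in each row of the matrix.
For example,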
# m is a <batch_size, 100> matrix of integers
val = 23
result = [0] * batch_size
for i, row_elems in enumerate(m):
    result[i] = row_elems.index(val)
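I cannot assume that 'val' appears only once per row; otherwise I would have implemented it with tf.argmax(m == val). In my case it is important to get the index of the first occurrence of 'val', not just any occurrence.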
Answer 0 (score: 10)
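It seems that tf.argmax works like np.argmax (according to the test): when the maximum value occurs multiple times, it returns the first index.
You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, the behavior of tf.argmax is currently undefined when the maximum value occurs multiple times.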
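To make the tie-breaking assumption concrete, here is a plain-Python model (not TensorFlow code) of taking an argmax over a 0/1 match mask, assuming ties resolve to the first maximum. The helper name `first_argmax` is hypothetical and exists only for this sketch. It also shows a caveat of the mask approach: a row that does not contain the value yields 0, which looks the same as a match in column 0.

```python
def first_argmax(values):
    """Hypothetical helper: index of the first maximal element."""
    # max() returns the first maximal item when several compare equal,
    # which models first-occurrence tie-breaking.
    return max(range(len(values)), key=values.__getitem__)

val = 23
m = [[1, 23, 5, 23],
     [23, 0, 0, 0],
     [7, 8, 9, 10]]  # this row does not contain val

# 0/1 mask per row, analogous to casting an equality test to int
masks = [[int(x == val) for x in row] for row in m]
result = [first_argmax(mask) for mask in masks]
print(result)  # [1, 0, 0]
```

The last two rows both produce 0, so a separate check (for example, whether the mask has any 1 at all) would be needed to tell "found at column 0" apart from "not found".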
If you are worried about the undefined behavior, you can apply tf.argmin to the return value of tf.where, as @Igor Tsvetkov suggested.
For example,
# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0  , 0  , val, 0  , val],
          [val, 0  , val, val, 0  ],
          [0  , val, 0  , 0  , 0  ]]

# (row, column) coordinates of every entry equal to val
tmp_indices = tf.where(tf.equal(m, val))
# smallest column index within each row
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed}))  # [2, 0, 1]
Note that tf.segment_min raises InvalidArgumentError if some row does not contain val. In your code, row_elems.index(val) also raises an exception when row_elems does not contain val.
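The two steps of this answer (list the coordinates of all matches, then take the smallest column per row) can be modelled in plain Python. This is only a sketch of the logic, not TensorFlow code; the function name `first_index_per_row` is hypothetical, and raising ValueError for rows without a match is an assumption chosen to mirror list.index().

```python
def first_index_per_row(m, val):
    """Hypothetical plain-Python model of 'where' followed by a per-row minimum."""
    # Step 1: (row, col) pairs of every match, in row-major order.
    coords = [(r, c)
              for r, row in enumerate(m)
              for c, x in enumerate(row)
              if x == val]
    # Step 2: smallest column index per row.
    firsts = {}
    for r, c in coords:
        firsts[r] = min(c, firsts.get(r, c))
    # Assumption: fail loudly when a row has no match, like list.index().
    missing = [r for r in range(len(m)) if r not in firsts]
    if missing:
        raise ValueError("rows without a match: %s" % missing)
    return [firsts[r] for r in range(len(m))]

m_feed = [[0, 0, 3, 0, 3],
          [3, 0, 3, 3, 0],
          [0, 3, 0, 0, 0]]
print(first_index_per_row(m_feed, 3))  # [2, 0, 1]
```

A row without a match contributes no coordinate pair at all, which is why the grouped minimum has nothing to work with for that row.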
Answer 1 (score: 3)
It looks a bit ugly but it works (assuming m and val are both tensors):
idx = list()
# tf.unpack / tf.pack were renamed to tf.unstack / tf.stack in TensorFlow 1.0
for t in tf.unstack(m, axis=0):
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
idx = tf.stack(idx, axis=0)
Edit:
As Yaroslav Bulatov mentioned above, you can get the same result with tf.map_fn:
def index1d(t):
    return tf.reduce_min(tf.where(tf.equal(t, val)))

idx = tf.map_fn(index1d, m, dtype=tf.int64)
Answer 2 (score: 1)
Here is another way to solve the problem, assuming every row contains a match.
import tensorflow as tf

val = 3
m = tf.constant([
    [0  , 0  , val, 0  , val],
    [val, 0  , val, val, 0  ],
    [0  , val, 0  , 0  , 0  ]])

# replace every entry with its column index on a match, or with an out-of-range index otherwise
match_indices = tf.where(                          # [[5, 5, 2, 5, 4],
    tf.equal(val, m),                              #  [0, 5, 2, 3, 5],
    x=tf.range(tf.shape(m)[1]) * tf.ones_like(m),  #  [5, 1, 5, 5, 5]]
    y=(tf.shape(m)[1]) * tf.ones_like(m))

result = tf.reduce_min(match_indices, axis=1)

with tf.Session() as sess:
    print(sess.run(result))  # [2, 0, 1]
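The idea in this answer is to replace each entry with its column index where it matches and with an out-of-range value (the row width) where it does not, then take the row-wise minimum. A plain-Python model of that logic is sketched below; it is not TensorFlow code, and the function name `first_index_or_width` is hypothetical.

```python
def first_index_or_width(m, val):
    """Hypothetical plain-Python model of the 'column index or sentinel' trick."""
    width = len(m[0])
    # Matching entries become their column index; all others become `width`.
    match_indices = [[c if x == val else width for c, x in enumerate(row)]
                     for row in m]
    # The row-wise minimum is the first matching column, or `width` if none.
    return [min(row) for row in match_indices]

m = [[0, 0, 3, 0, 3],
     [3, 0, 3, 3, 0],
     [0, 0, 0, 0, 0]]  # last row has no match
print(first_index_or_width(m, 3))  # [2, 0, 5]
```

In this model a row without a match does not raise an error; it simply yields the row width, which can never be a valid column index and so can serve as a "not found" marker.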
Answer 3 (score: 1)
Here is a solution that also handles the case where the matrix does not contain the element (a solution from DeepMind's GitHub repository):
def get_first_occurrence_indices(sequence, eos_idx):
    '''
    args:
        sequence: [batch, length]
        eos_idx: scalar
    '''
    batch_size, maxlen = sequence.get_shape().as_list()
    eos_idx = tf.convert_to_tensor(eos_idx)
    tensor = tf.concat(
        [sequence, tf.tile(eos_idx[None, None], [batch_size, 1])], axis=-1)
    index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
    index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
    index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
                                             index_all_occurrences[:, 0])
    index_first_occurrences.set_shape([batch_size])
    index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)
    return index_first_occurrences
And a usage example:
import tensorflow as tf
mat = tf.Variable([[1,2,3,4,5], [2,3,4,5,6], [3,4,5,6,7], [0,0,0,0,0]], dtype = tf.int32)
idx = 3
first_occurrences = get_first_occurrence_indices(mat, idx)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(first_occurrences)  # [3, 2, 1, 5]
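The function above appends the searched value to the end of every row so that each row is guaranteed to contain a match, then returns the first match position plus one, capped at the row length. The plain-Python sketch below models that logic for illustration only; the function name `first_occurrence_lengths` is hypothetical.

```python
def first_occurrence_lengths(sequences, eos_idx):
    """Hypothetical plain-Python model of the sentinel-append approach."""
    maxlen = len(sequences[0])
    result = []
    for row in sequences:
        # Appending eos_idx guarantees that index() always finds a match.
        padded = list(row) + [eos_idx]
        first = padded.index(eos_idx)
        # Position + 1, capped at the original row length.
        result.append(min(first + 1, maxlen))
    return result

mat = [[1, 2, 3, 4, 5],
       [2, 3, 4, 5, 6],
       [3, 4, 5, 6, 7],
       [0, 0, 0, 0, 0]]
print(first_occurrence_lengths(mat, 3))  # [3, 2, 1, 5]
```

Note that the returned values are one-based (position plus one), and that in this model a row whose only match is in the last column gives the same result as a row with no match at all, since both are capped at the row length.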