在张量流中查找值张量到另一个张量的索引

时间:2020-03-24 19:30:39

标签: tensorflow

我有2个向量

a = [0 0 37 7 8 0 0] b = [0 0 4 37 8]

我想找到b的值到a的索引,所以输出看起来像

c = [0 0 -1 2 4]

我该如何在Tensorflow操作中做到这一点

1 个答案:

答案 0 :(得分:1)

解决方案:

import tensorflow as tf

a = tf.constant([0, 0, 37, 7, 8, 0, 0])
b = tf.constant([0, 0, 4, 37, 8])

expanded_b = b[..., None]
tiled_a = tf.tile(a[None, ...], [tf.shape(b)[0], 1])
mult = tf.cast(tf.equal(expanded_b, tiled_a), tf.float32)
sub = tf.cast(tf.math.equal(tf.reduce_sum(mult, -1), 0), tf.int64)
res = tf.argmax(mult, axis=-1) - sub

with tf.Session() as sess:
    print(res.eval()) # [ 0  0 -1  2  4]

说明:

a = tf.constant([0, 0, 37, 7, 8, 0, 0])
b = tf.constant([0, 0, 4, 37, 8])

expanded_b = b[..., None]
# expanded_b:
# [[ 0]
#  [ 0]
#  [ 4]
#  [37]
#  [ 8]]
tiled_a = tf.tile(a[None, ...], [tf.shape(b)[0], 1])
# tiled_a
# [[ 0  0 37  7  8  0  0]
#  [ 0  0 37  7  8  0  0]
#  [ 0  0 37  7  8  0  0]
#  [ 0  0 37  7  8  0  0]
#  [ 0  0 37  7  8  0  0]]

# Now expanded_b and tiled_a are broadcastable so we can compare
# each element of b to all elements in a in parallel
mult = tf.cast(tf.equal(expanded_b, tiled_a), tf.float32)
# mult
# [[1. 1. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 0. 0. 0.]
#  [0. 0. 1. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 1. 0. 0.]]

# from mult we need first index from axis -1 that is != 0 (using argmax)
# sub shows which rows have all zeros (no element of b in a)
# for such rows we put value 1
sub = tf.cast(tf.math.equal(tf.reduce_sum(mult, -1), 0), tf.int64) 
# sub
# [0 0 1 0 0]
# result
res = tf.argmax(mult, axis=-1) - sub