I have two vectors:
a = [0 0 37 7 8 0 0]
b = [0 0 4 37 8]
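For each value in b, I want its index in a (the first occurrence, or -1 if the value does not appear in a), so the output would look like:
c = [0 0 -1 2 4]
How can I do this with TensorFlow ops?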
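To pin down the expected semantics (first occurrence in a, otherwise -1), here is a plain-Python reference sketch with no TensorFlow involved. The function name `index_in` is hypothetical, and "first occurrence" is an assumption read off the example, where 0 maps to index 0 even though 0 also appears at later positions in a.

```python
def index_in(a, b):
    """For each value in b, return its first index in a, or -1 if it does not occur."""
    first_pos = {}
    for i, value in enumerate(a):
        first_pos.setdefault(value, i)  # keep only the first occurrence of each value
    return [first_pos.get(value, -1) for value in b]


print(index_in([0, 0, 37, 7, 8, 0, 0], [0, 0, 4, 37, 8]))  # [0, 0, -1, 2, 4]
```

A reference like this is handy for checking a vectorized TensorFlow version on small inputs.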
Answer 0 (score: 1)
Solution:
import tensorflow as tf
a = tf.constant([0, 0, 37, 7, 8, 0, 0])
b = tf.constant([0, 0, 4, 37, 8])
expanded_b = b[..., None]
tiled_a = tf.tile(a[None, ...], [tf.shape(b)[0], 1])
mult = tf.cast(tf.equal(expanded_b, tiled_a), tf.float32)
sub = tf.cast(tf.math.equal(tf.reduce_sum(mult, -1), 0), tf.int64)
res = tf.argmax(mult, axis=-1) - sub
# TensorFlow 2.x executes eagerly, so no tf.Session is needed
print(res.numpy())  # [ 0  0 -1  2  4]
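The same steps can be mirrored in NumPy to check the logic without TensorFlow; this is an illustrative sketch, not part of the original answer. NumPy's `argmax` is documented to return the index of the first occurrence of the maximum, which is what the rows for 0 (several 1.0 entries) need.

```python
import numpy as np

a = np.array([0, 0, 37, 7, 8, 0, 0])
b = np.array([0, 0, 4, 37, 8])

# Broadcasting (5, 1) against (1, 7) plays the role of tf.tile here
mult = (b[:, None] == a[None, :]).astype(np.float32)  # shape (5, 7), 1.0 where b[i] == a[j]
sub = (mult.sum(axis=-1) == 0).astype(np.int64)       # 1 where b[i] is absent from a
res = mult.argmax(axis=-1) - sub                      # first match index, or 0 - 1 = -1
print(res.tolist())  # [0, 0, -1, 2, 4]
```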
Explanation:
a = tf.constant([0, 0, 37, 7, 8, 0, 0])
b = tf.constant([0, 0, 4, 37, 8])
expanded_b = b[..., None]
# expanded_b:
# [[ 0]
# [ 0]
# [ 4]
# [37]
# [ 8]]
tiled_a = tf.tile(a[None, ...], [tf.shape(b)[0], 1])
# tiled_a
# [[ 0 0 37 7 8 0 0]
# [ 0 0 37 7 8 0 0]
# [ 0 0 37 7 8 0 0]
# [ 0 0 37 7 8 0 0]
# [ 0 0 37 7 8 0 0]]
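# expanded_b (shape (5, 1)) and tiled_a (shape (5, 7)) are broadcastable, so tf.equal
# compares each element of b with every element of a in parallel
# (tf.equal broadcasts on its own, so a[None, ...] would also work without tf.tile)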
mult = tf.cast(tf.equal(expanded_b, tiled_a), tf.float32)
# mult
# [[1. 1. 0. 0. 0. 1. 1.]
# [1. 1. 0. 0. 0. 1. 1.]
# [0. 0. 0. 0. 0. 0. 0.]
# [0. 0. 1. 0. 0. 0. 0.]
# [0. 0. 0. 0. 1. 0. 0.]]
# For each row of mult we need the first index along axis -1 whose value is nonzero (argmax)
# sub marks the rows that are all zeros (that element of b does not occur in a):
# such rows get 1, so argmax's 0 minus 1 becomes -1
sub = tf.cast(tf.math.equal(tf.reduce_sum(mult, -1), 0), tf.int64)
# sub
# [0 0 1 0 0]
# result: argmax of each row minus sub
res = tf.argmax(mult, axis=-1) - sub
# res
# [ 0  0 -1  2  4]
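A possible TensorFlow 2.x variant, sketched under the assumption that the first occurrence is wanted: rather than depending on which of several tied maxima `tf.argmax` reports, it replaces non-matching positions with len(a) and takes `tf.reduce_min`. It relies on broadcasting in `tf.equal` and `tf.where`, so `tf.tile` is not needed. The helper name `first_index_or_minus_one` is hypothetical.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def first_index_or_minus_one(a, b):
    """Hypothetical helper: first index of each b[i] in a, or -1 if b[i] is not in a."""
    n = tf.shape(a)[0]                          # len(a), int32 scalar
    positions = tf.range(n)                     # [0, 1, ..., n - 1]
    matches = tf.equal(b[:, None], a[None, :])  # (len(b), len(a)) bool, via broadcasting
    # Non-matching positions become n, which is larger than any valid index
    candidates = tf.where(matches, positions[None, :], n)
    first = tf.reduce_min(candidates, axis=1)   # smallest matching index per row
    # A row whose minimum is still n had no match at all
    return tf.where(tf.equal(first, n), tf.constant(-1), first)


a = tf.constant([0, 0, 37, 7, 8, 0, 0])
b = tf.constant([0, 0, 4, 37, 8])
print(first_index_or_minus_one(a, b).numpy())  # [ 0  0 -1  2  4]
```

The result has the same dtype as the inputs' index tensor (int32 here), whereas the answer's version returns int64 from `tf.argmax`.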