我有一个a
类型的张量tf.int64
。我想根据给定的python列表过滤出该张量。
例如-
l = [1,2,3]
a = tf.constant([1,2,3,4], dtype=tf.int64)
需要一个1,2,3
以外的值4
的张量。就是在a
的基础上过滤掉l
。如何在TensorFlow中做到这一点?
答案 0 :(得分:2)
您可以使用tf.sets.set_intersection
:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
# tf.sets.intersection in more recent versions
b = tf.sets.set_intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))
b = tf.squeeze(tf.sparse.to_dense(b), 0)
print(sess.run(b))
# [1 2 3]
但是,在许多情况下,这可能无法满足您的要求。如果存在重复的元素,它将丢弃它们,并且也会对输出进行排序。通常,您可以执行以下操作:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
b = tf.boolean_mask(a, m)
print(sess.run(b))
# [1 2 3]
这是二次比较,但我认为TensorFlow中没有像np.isin
这样更好的东西。