根据TensorFlow中的python列表过滤张量

时间:2019-07-12 11:48:09

标签: python tensorflow

我有一个a类型的张量tf.int64。我想根据给定的python列表过滤出该张量。
例如-

l = [1,2,3]
a = tf.constant([1,2,3,4], dtype=tf.int64) 

需要一个1,2,3以外的值4的张量。就是在a的基础上过滤掉l。如何在TensorFlow中做到这一点?

1 个答案:

答案 0 :(得分:2)

您可以使用tf.sets.set_intersection

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
    # tf.sets.intersection in more recent versions
    b = tf.sets.set_intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))
    b = tf.squeeze(tf.sparse.to_dense(b), 0)
    print(sess.run(b))
    # [1 2 3]

但是,在许多情况下,这可能无法满足您的要求。如果存在重复的元素,它将丢弃它们,并且也会对输出进行排序。通常,您可以执行以下操作:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
    m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
    b = tf.boolean_mask(a, m)
    print(sess.run(b))
    # [1 2 3]

这是二次比较,但我认为TensorFlow中没有像np.isin这样更好的东西。