在Tensorflow中设置交集

时间:2018-11-16 22:22:06

标签: python tensorflow intersection set-intersection

我要检查稀疏张量中是否包含一组给定值。稀疏张量称为labels,只有一维包含ID列表。

最后,这似乎是一个简单的集合交集问题,所以我尝试了这个。

sparse_ids = load_ids_as_sparse_tensor()
wanted_ids = tf.constant([34, 56, 12])
intersection = tf.sets.set_intersection(
    wanted_ids,
    tf.cast(sparse_ids.values, tf.int32)
)
contains_any_wanted_ids = tf.not_equal(tf.size(intersection), 0)

但是,我遇到此错误:

ValueError: Shape must be at least rank 2 but is rank 1 for 'DenseToDenseSetOperation' (op: 'DenseToDenseSetOperation') with input shapes: [3], [?].

有什么想法吗?

1 个答案:

答案 0 :(得分:1)

以下代码有效。但是,我不确定结果是否是您想要的。

import tensorflow as tf
a = tf.constant([34, 56, 12])
b = tf.constant([56])
intersection = tf.sets.set_intersection(a[None,:],b[None,:])
sess=tf.Session()
sess.run(intersection)

输出:

  

SparseTensorValue(索引=数组([[0,0]],dtype = int64),   values = array([56]),density_shape = array([1,1],dtype = int64))