在TensorFlow中,tf.unique
函数可用于返回一维Tensor
的不同元素。如何沿着高维Tensor
的轴0获得不同的子Tensor
?例如,给定以下Tensor
,所需的distinct
函数将返回指定的结果:
input = tf.constant([
[0,3],
[0,1],
[0,4],
[0,1],
[1,5],
[3,9],
[3,2],
[3,6],
[3,5],
[3,3]])
distinct(input) == tf.constant([
[0,3],
[0,1],
[0,4],
[1,5],
[3,9],
[3,2],
[3,6],
[3,5],
[3,3]])
如何为任意数量维度的Tensor
s生成不同的多维元素?
答案 0 :(得分:3)
无保留订单
您可以使用tf.py_function
并调用np.unique
沿轴= 0返回唯一的多维张量。请注意,这会找到唯一的行,但不会保留顺序。
def distinct(a):
_a = np.unique(a, axis=0)
return _a
>> input = tf.constant([
[0,3],
[0,1],
[0,4],
[0,1],
[1,5],
[3,9],
[3,2],
[3,6],
[3,5],
[3,3]])
>> tf.py_function(distinct, [input], tf.int32)
<tf.Tensor: id=940, shape=(9, 2), dtype=int32, numpy=
array([[0, 1],
[0, 3],
[0, 4],
[1, 5],
[3, 2],
[3, 3],
[3, 5],
[3, 6],
[3, 9]], dtype=int32)>
保留订单
def distinct_with_order_preserved(a):
_a = a.numpy()
return pd.DataFrame(_a).drop_duplicates().values
>> tf.py_function(distinct_with_order_preserved, [input], tf.int32)
<tf.Tensor: id=950, shape=(9, 2), dtype=int32, numpy=
array([[0, 3],
[0, 1],
[0, 4],
[1, 5],
[3, 9],
[3, 2],
[3, 6],
[3, 5],
[3, 3]], dtype=int32)>
答案 1 :(得分:0)
一种方法是寻找沿轴0的前一个子Tensor
相等的元素,然后将其过滤掉:
tf.equal
来获取输入的各个轴-1元素与它们自身沿轴0相交的成对相等性。tf.math.reduce_all
聚合成对相等性,直到您拥有输入的轴0元素的二维相等矩阵为止。tf.reduce_any
来查找哪个轴0元素等于以后的任何元素;它们是将要删除的重复项。tf.math.logical_not
和tf.boolean_mask
仅获得轴0的非重复元素。此过程在以下经TensorFlow 2.0 beta测试的Python代码中实现:
def distinct(input:tf.Tensor) -> tf.Tensor:
"""Returns only the distinct sub-Tensors along the 0th dimension of the provided Tensor"""
is_equal = tf.equal(input[:,tf.newaxis], input[tf.newaxis,:])
while len(is_equal.shape) > 2:
is_equal = tf.math.reduce_all(is_equal, axis=2)
all_true = tf.constant(True, shape=is_equal.shape)
true_upper_tri = tf.linalg.band_part(all_true, 0, -1)
false_upper_tri = tf.math.logical_not(true_upper_tri)
is_equal_one_way = tf.math.logical_and(is_equal, false_upper_tri)
is_duplicate = tf.reduce_any(is_equal_one_way, axis=1)
is_distinct = tf.math.logical_not(is_duplicate)
distinct_elements = tf.boolean_mask(input, is_distinct, 0)
return distinct_elements