TensorFlow-返回多维张量的不同子张量

时间:2019-07-19 16:03:51

标签: python tensorflow tensorflow2.0

在TensorFlow中,tf.unique函数可用于返回一维Tensor的不同元素。如何沿着高维Tensor的轴0获得不同的子Tensor?例如,给定以下Tensor,所需的distinct函数将返回指定的结果:

input = tf.constant([
    [0,3],
    [0,1],
    [0,4],
    [0,1],
    [1,5],
    [3,9],
    [3,2],
    [3,6],
    [3,5],
    [3,3]])

distinct(input) == tf.constant([
    [0,3],
    [0,1],
    [0,4],
    [1,5],
    [3,9],
    [3,2],
    [3,6],
    [3,5],
    [3,3]])

如何为任意数量维度的Tensor s生成不同的多维元素?

2 个答案:

答案 0 :(得分:3)

无保留订单

您可以使用tf.py_function并调用np.unique沿轴= 0返回唯一的多维张量。请注意,这会找到唯一的行,但不会保留顺序。

def distinct(a):
    _a =  np.unique(a, axis=0)
    return _a

>> input = tf.constant([
[0,3],
[0,1],
[0,4],
[0,1],
[1,5],
[3,9],
[3,2],
[3,6],
[3,5],
[3,3]])

>> tf.py_function(distinct, [input], tf.int32)
<tf.Tensor: id=940, shape=(9, 2), dtype=int32, numpy=
array([[0, 1],
   [0, 3],
   [0, 4],
   [1, 5],
   [3, 2],
   [3, 3],
   [3, 5],
   [3, 6],
   [3, 9]], dtype=int32)>

保留订单

def distinct_with_order_preserved(a):
    _a = a.numpy()
    return pd.DataFrame(_a).drop_duplicates().values

>> tf.py_function(distinct_with_order_preserved, [input], tf.int32)
<tf.Tensor: id=950, shape=(9, 2), dtype=int32, numpy=
array([[0, 3],
   [0, 1],
   [0, 4],
   [1, 5],
   [3, 9],
   [3, 2],
   [3, 6],
   [3, 5],
   [3, 3]], dtype=int32)>

答案 1 :(得分:0)

一种方法是寻找沿轴0的前一个子Tensor相等的元素,然后将其过滤掉:

  1. 使用tf.equal来获取输入的各个轴-1元素与它们自身沿轴0相交的成对相等性。
  2. 使用tf.math.reduce_all聚合成对相等性,直到您拥有输入的轴0元素的二维相等矩阵为止。
  3. 生成False值的上三角矩阵
  4. 使用该三角形矩阵将相等性比较限制为沿轴0的一个方向。
  5. 使用tf.reduce_any来查找哪个轴0元素等于以后的任何元素;它们是将要删除的重复项。
  6. 使用tf.math.logical_nottf.boolean_mask仅获得轴0的非重复元素。

此过程在以下经TensorFlow 2.0 beta测试的Python代码中实现:

def distinct(input:tf.Tensor) -> tf.Tensor:
    """Returns only the distinct sub-Tensors along the 0th dimension of the provided Tensor"""
    is_equal = tf.equal(input[:,tf.newaxis], input[tf.newaxis,:])
    while len(is_equal.shape) > 2:
        is_equal = tf.math.reduce_all(is_equal, axis=2)
    all_true = tf.constant(True, shape=is_equal.shape)
    true_upper_tri = tf.linalg.band_part(all_true, 0, -1)
    false_upper_tri = tf.math.logical_not(true_upper_tri)
    is_equal_one_way = tf.math.logical_and(is_equal, false_upper_tri)
    is_duplicate = tf.reduce_any(is_equal_one_way, axis=1)
    is_distinct = tf.math.logical_not(is_duplicate)
    distinct_elements = tf.boolean_mask(input, is_distinct, 0)
    return distinct_elements