如何在张量流中找到矩阵行的公共值

时间:2019-07-04 21:26:04

标签: python tensorflow slice

我有一个这样的张量:

tf_a2 = tf.constant([[1, 2,  5 ],
                     [1, 4,  6 ],
                     [0, 10, 10],
                     [2, 4,  6 ],
                     [2, 4,  10]])

我想在此矩阵中找到重复n次以上的整个索引。

例如:1被重复two times2被重复three times5被重复one time。考虑行之间的重复。另外,我想完全跳过数字10(10个常数)。

在这里n=2,所以结果看起来像:因为2 and 4的重复次数超过了two times

                    [[0, 2,  0 ],
                     [0, 4,  0 ],
                     [0, 0,  0 ],
                     [2, 4,  0 ],
                     [2, 4,  0 ]]

我找到了一个示例here,但说明是针对Matlab代码的。

提前感谢:)

1 个答案:

答案 0 :(得分:1)

首先,您可以使用tf.unique_with_counts查找一维张量中的唯一元素。

import tensorflow as tf

tf_a2 = tf.constant([[1, 2,  5 ],
                     [1, 4,  6 ],
                     [0, 10, 10],
                     [2, 4,  6 ],
                     [2, 4,  10]])
n = 2
constant = 10

y, idx, count = tf.unique_with_counts(tf.reshape(tf_a2,[-1,]))
# y = [ 1  2  5  4  6  0 10]
# idx = [0 1 2 0 3 4 5 6 6 1 3 4 1 3 6]
# count = [2 3 1 3 2 1 3]

然后您可以将重复时间映射到原始张量。

count_mask = tf.reshape(tf.gather(count,idx),tf_a2.shape)
# [[2 3 1]
#  [2 3 2]
#  [1 3 3]
#  [3 3 2]
#  [3 3 3]]

最后,您可以跳过数字10,并在tf.where之前得到期望的结果。

# skip constant and filter n time
result = tf.where(tf.logical_and(tf.greater(count_mask,n),
                                 tf.not_equal(tf_a2,constant)),
                  tf_a2,
                  tf.zeros_like(tf_a2))

with tf.Session() as sess:
    print(sess.run(result))

# [[0 2 0]
#  [0 4 0]
#  [0 0 0]
#  [2 4 0]
#  [2 4 0]]