我有一个这样的张量:
tf_a2 = tf.constant([[1, 2, 5 ],
[1, 4, 6 ],
[0, 10, 10],
[2, 4, 6 ],
[2, 4, 10]])
我想在此矩阵中找到重复n
次以上的整个索引。
例如:1
被重复two times
。 2
被重复three times
。 5
被重复one time
。考虑行之间的重复。另外,我想完全跳过数字10
(10个常数)。
在这里n=2
,所以结果看起来像:因为2 and 4
的重复次数超过了two times
。
[[0, 2, 0 ],
[0, 4, 0 ],
[0, 0, 0 ],
[2, 4, 0 ],
[2, 4, 0 ]]
我找到了一个示例here,但说明是针对Matlab代码的。
提前感谢:)
答案 0 :(得分:1)
首先,您可以使用tf.unique_with_counts
查找一维张量中的唯一元素。
import tensorflow as tf
tf_a2 = tf.constant([[1, 2, 5 ],
[1, 4, 6 ],
[0, 10, 10],
[2, 4, 6 ],
[2, 4, 10]])
n = 2
constant = 10
y, idx, count = tf.unique_with_counts(tf.reshape(tf_a2,[-1,]))
# y = [ 1 2 5 4 6 0 10]
# idx = [0 1 2 0 3 4 5 6 6 1 3 4 1 3 6]
# count = [2 3 1 3 2 1 3]
然后您可以将重复时间映射到原始张量。
count_mask = tf.reshape(tf.gather(count,idx),tf_a2.shape)
# [[2 3 1]
# [2 3 2]
# [1 3 3]
# [3 3 2]
# [3 3 3]]
最后,您可以跳过数字10
,并在tf.where
之前得到期望的结果。
# skip constant and filter n time
result = tf.where(tf.logical_and(tf.greater(count_mask,n),
tf.not_equal(tf_a2,constant)),
tf_a2,
tf.zeros_like(tf_a2))
with tf.Session() as sess:
print(sess.run(result))
# [[0 2 0]
# [0 4 0]
# [0 0 0]
# [2 4 0]
# [2 4 0]]