我想通过保留最大条目的10%来过滤张量。是否有一个Tensorflow函数来做到这一点?可能的实现会是什么样子?我正在寻找可以处理形状为[N,W,H,C]
和[N,W*H*C]
的张量的东西。
通过过滤器,我的意思是张量的形状保持不变,但仅保留最大的10%。因此,除了最大的10%之外,所有条目都变为零。
有可能吗?
答案 0 :(得分:3)
正确的方法是计算90%,例如使用tf.contrib.distributions.percentile
:
import tensorflow as tf
images = ... # [N, W, H, C]
n = tf.shape(images)[0]
images_flat = tf.reshape(images, [n, -1])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, 1]),
images, tf.zeros_like(images))
如果您想为TensorFlow 2.x(tf.contrib
will be removed)做好准备,可以改用TensorFlow Probability,percentile
函数将在将来永久存在。< / p>
编辑:如果要按通道进行过滤,则可以像下面这样稍微修改代码:
import tensorflow as tf
images = ... # [N, W, H, C]
shape = tf.shape(images)
n, c = shape[0], shape[3]
images_flat = tf.reshape(images, [n, -1, c])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, c]),
images, tf.zeros_like(images))
答案 1 :(得分:0)
我还没有找到任何内置方法。尝试以下解决方法:
import numpy as np
import tensorflow as tf
def filter(tensor, ratio):
num_entries = tf.reduce_prod(tensor.shape)
num_to_keep = tf.cast(tf.multiply(ratio, tf.cast(num_entries, tf.float32)), tf.int32)
# Calculate threshold
x = tf.contrib.framework.sort(tf.reshape(tensor, [num_entries]))
threshold = x[-num_to_keep]
# Filter the tensor
mask = tf.cast(tf.greater_equal(tensor, threshold), tf.float32)
return tf.multiply(tensor, mask)
tensor = tf.constant(np.arange(40).reshape(2, 4, 5), dtype=tf.float32)
filtered_tensor = filter(tensor, 0.1)
# Print result
tf.InteractiveSession()
print(tensor.eval())
print(filtered_tensor.eval())