Tensorflow:保留张量最大条目的10%

时间:2019-03-05 10:18:55

标签: python tensorflow

我想通过保留最大条目的10%来过滤张量。是否有一个Tensorflow函数来做到这一点?可能的实现会是什么样子?我正在寻找可以处理形状为[N,W,H,C][N,W*H*C]的张量的东西。

通过过滤器,我的意思是张量的形状保持不变,但仅保留最大的10%。因此,除了最大的10%之外,所有条目都变为零。

有可能吗?

2 个答案:

答案 0 :(得分:3)

正确的方法是计算90%,例如使用tf.contrib.distributions.percentile

import tensorflow as tf

images = ... # [N, W, H, C]
n = tf.shape(images)[0]
images_flat = tf.reshape(images, [n, -1])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, 1]),
                        images, tf.zeros_like(images))

如果您想为TensorFlow 2.x(tf.contrib will be removed)做好准备,可以改用TensorFlow Probabilitypercentile函数将在将来永久存在。< / p>

编辑:如果要按通道进行过滤,则可以像下面这样稍微修改代码:

import tensorflow as tf

images = ... # [N, W, H, C]
shape = tf.shape(images)
n, c = shape[0], shape[3]
images_flat = tf.reshape(images, [n, -1, c])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, c]),
                        images, tf.zeros_like(images))

答案 1 :(得分:0)

我还没有找到任何内置方法。尝试以下解决方法:

import numpy as np
import tensorflow as tf

def filter(tensor, ratio):
  num_entries = tf.reduce_prod(tensor.shape)
  num_to_keep = tf.cast(tf.multiply(ratio, tf.cast(num_entries, tf.float32)), tf.int32)
  # Calculate threshold
  x = tf.contrib.framework.sort(tf.reshape(tensor, [num_entries]))
  threshold = x[-num_to_keep]
  # Filter the tensor
  mask = tf.cast(tf.greater_equal(tensor, threshold), tf.float32)
  return tf.multiply(tensor, mask)

tensor = tf.constant(np.arange(40).reshape(2, 4, 5), dtype=tf.float32)
filtered_tensor = filter(tensor, 0.1)
# Print result
tf.InteractiveSession()
print(tensor.eval())
print(filtered_tensor.eval())