标签: tensorflow tensor
我想创建一个k稀疏CAE,在该CAE中,每个图像在最后一个编码器图像中仅保留k个最高特征图。因此,我需要找出哪个频道是最高的。
我目前的方法是使用tf.nn.top_k(f_maps_mean, k_real),其中f_maps是所有通道的均值:tf.reduce_mean(inp_tensor, [1, 2])。 k_real只是k%*特征图数量的舍入整数。
tf.nn.top_k(f_maps_mean, k_real)
tf.reduce_mean(inp_tensor, [1, 2])
然后,我想通过这些索引或bool蒙版剪切张量,但无法将批处理大小归因于此操作。
如果有人知道更直接的方法,我也将不胜感激。