如何扩展theanos downsample.max_pool_2d_same_size,以便不仅在特征映射中汇集,而且还以有效的方式汇集在这些映射之间?
让我们说我有3张特征地图,每张都是10x10,这将是4D Tensor(1,3,10,10)。首先让(10,10)特征映射的每个最大池((2,2),不重叠)。结果是3个稀疏特征映射,仍然是(10,10)但是大多数值等于零:在(2,2)窗口内最多一个值大于零。这就是downsample.max_pool_2d_same_size的作用。
接下来,我想将某个(2,2)窗口的每个最大值与同一位置窗口的所有其他特征图的所有其他最大值进行比较。 我想只保留所有要素图的最大值。结果再次是3个特征图(10,10),几乎所有的值都为零。
有快速的方法吗? 我不介意其他max_pooling函数,但我需要maxima的确切位置以用于池化/解放目的(但这是另一个主题)。
答案 0 :(得分:2)
我用带有cudnn的烤宽面条解决了它。以下是如何获取最大池操作(2d和3d)的索引的一些最小示例。见https://groups.google.com/forum/#!topic/lasagne-users/BhtKsRmFei4
import numpy as np
import theano
import theano.tensor as T
from theano.tensor.type import TensorType
from theano.configparser import config
import lasagne
def tensor5(name=None, dtype=None):
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, (False, False, False, False, False))
return type(name)
def max_pooling_2d():
input_var = T.tensor4('input')
input_layer = lasagne.layers.InputLayer(shape=(None, 2, 4, 4), input_var=input_var)
max_pool_layer = lasagne.layers.MaxPool2DLayer(input_layer, pool_size=(2, 2))
pool_in, pool_out = lasagne.layers.get_output([input_layer, max_pool_layer])
indices = T.grad(None, wrt=pool_in, known_grads={pool_out: T.ones_like(pool_out)})
get_indices_fn = theano.function([input_var], indices,allow_input_downcast=True)
data = np.random.randint(low=0, high=9, size=32).reshape((1,2,4,4))
indices = get_indices_fn(data)
print data, "\n\n", indices
def max_pooling_3d():
input_var = tensor5('input')
input_layer = lasagne.layers.InputLayer(shape=(1, 1, 2, 4, 4), input_var=input_var)
# 5 input dimensions: (batchsize, channels, 3 spatial dimensions)
max_pool_layer = lasagne.layers.dnn.MaxPool3DDNNLayer(input_layer, pool_size=(2, 2, 2))
pool_in, pool_out = lasagne.layers.get_output([input_layer, max_pool_layer])
indices = T.grad(None, wrt=pool_in, known_grads={pool_out: T.ones_like(pool_out)})
get_indices_fn = theano.function([input_var], indices,allow_input_downcast=True)
data = np.random.randint(low=0, high=9, size=32).reshape((1,1,2,4,4))
indices = get_indices_fn(data)
print data, "\n\n", indices