Has anyone written Weldon pooling for Keras?

Asked: 2018-04-23 19:15:59

Tags: keras max-pooling

Has Weldon pooling [1] been implemented in Keras?

I can see that it has been implemented in PyTorch by the authors [2], but I cannot find a similar implementation for Keras.

[1] T. Durand, N. Thome, and M. Cord. WELDON: Weakly Supervised Learning of Deep Convolutional Neural Networks. In CVPR, 2016.
[2] https://github.com/durandtibo/weldon.resnet.pytorch/tree/master/weldon

1 Answer:

Answer 0 (score: 1)

Here is one based on the Lua version (there is a PyTorch implementation, but I think it has an error when taking the average of the max + min). I'm assuming the Lua version's averaging of the top maximum and minimum values is still correct. Per channel, Weldon pooling takes the average of the kmax highest activations and, when kmin > 0, adds the average of the kmin lowest activations (the negative evidence). I haven't tested all of the custom-layer aspects, but this was close enough to get something going; comments welcome.

# Imports assume Keras 2 with the TensorFlow backend.
import tensorflow as tf
from keras.engine import Layer, InputSpec
from keras.utils import conv_utils


class WeldonPooling(Layer):
    """Weldon selective spatial pooling with negative evidence.

    Per channel, averages the kmax highest activations and, if kmin > 0,
    adds the average of the kmin lowest activations.
    """

    def __init__(self, kmax, kmin=-1, data_format=None, **kwargs):
        super(WeldonPooling, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)
        self.kmax = kmax
        self.kmin = kmin

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            return (input_shape[0], input_shape[3])
        else:
            return (input_shape[0], input_shape[1])

    def get_config(self):
        config = {'kmax': self.kmax,
                  'kmin': self.kmin,
                  'data_format': self.data_format}
        base_config = super(WeldonPooling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        # Work in channels-first order internally.
        if self.data_format == 'channels_last':
            inputs = tf.transpose(inputs, [0, 3, 1, 2])
        kmax = self.kmax
        kmin = self.kmin
        shape = tf.shape(inputs)
        batch_size = shape[0]
        num_channels = shape[1]
        h = shape[2]
        w = shape[3]
        n = h * w
        # Flatten the spatial dimensions and sort each channel's
        # activations in descending order.
        view = tf.reshape(inputs, [batch_size, num_channels, n])
        sorted_values, _ = tf.nn.top_k(view, n, sorted=True)
        # Average of the kmax highest activations.
        output = tf.div(
            tf.reduce_sum(
                tf.slice(sorted_values, [0, 0, 0],
                         [batch_size, num_channels, kmax]), 2),
            kmax)
        if kmin > 0:
            # Add the average of the kmin lowest activations
            # (the negative evidence).
            output = tf.add(
                output,
                tf.div(
                    tf.reduce_sum(
                        tf.slice(sorted_values, [0, 0, n - kmin],
                                 [batch_size, num_channels, kmin]), 2),
                    kmin))
        return tf.reshape(output, [batch_size, num_channels])
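To sanity-check the layer, here's a minimal usage sketch. The toy model, input shape, and kmax/kmin values are placeholders I made up for illustration, not values from the paper:

from keras.models import Sequential
from keras.layers import Conv2D

# Toy model: a conv feature map followed by Weldon pooling.
# Shapes and kmax/kmin here are illustrative only.
model = Sequential([
    Conv2D(16, 3, padding='same', input_shape=(32, 32, 3)),
    WeldonPooling(kmax=3, kmin=3),
])
model.compile(optimizer='sgd', loss='mse')
print(model.output_shape)  # (None, 16): one pooled value per channel

The spatial dimensions collapse to a single score per channel, which is what lets this replace global average/max pooling in a weakly supervised setup.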