Keras custom layer using argmax kills the gradients

Date: 2018-11-02 18:37:47

Tags: python tensorflow keras

After implementing a custom layer in Keras and using it in my network, I realized that all the layers before it were no longer being updated during training. Indeed, inspecting the TensorBoard distributions of the weights and gradients computed for my network, the gradients of every layer preceding the custom one are zero.

The layer I want to implement does the following:

Goal: given a tensor of size B x H x W x D, with D = 1, add two new channels of size H x W to each sample of the batch (the values of these channels are given as constant inputs, so they are neither learned nor dynamic). On the resulting tensor of shape B x H x W x 3, compute the argmax over the first channel of each sample (i.e. over the 2D image) and use the resulting indices to select the elements of the second and third channels.

Shapes: given B x H x W x D = (B x 8 x 8 x 1), the layer outputs a tensor of shape B x 2 x 1; the intermediate tensor (after adding the channels) has shape B x H x W x 3.

Numerical example: suppose B = 1 and H = W = 2, and let the input tensor have the following values:

input[0, :, :, 0] = [[0.1, 0.2],
                     [0.3, 0.4]]

After adding the new channels:

input[0, :, :, 0] = [[0.1, 0.2],
                     [0.3, 0.4]]

input[0, :, :, 1] = [[a, b],
                     [c, d]]

input[0, :, :, 2] = [[e, f],
                     [g, h]]

The layer computes the argmax over the first channel and returns the corresponding elements of the second and third channels:

indices = argmax(input[0, :, :, 0]) = [1, 1]
input[0, indices[0], indices[1], 1] = [d]
input[0, indices[0], indices[1], 2] = [h]

Layer(input) = [
                 [[d], [h]]
               ]
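
To make the example concrete, here is a quick NumPy check with numeric stand-ins for a..h (only to verify the expected output; it is not part of the layer):

import numpy as np

sample = np.zeros((2, 2, 3))
sample[:, :, 0] = [[0.1, 0.2], [0.3, 0.4]]
sample[:, :, 1] = [[10., 20.], [30., 40.]]  # stands in for [[a, b], [c, d]]
sample[:, :, 2] = [[50., 60.], [70., 80.]]  # stands in for [[e, f], [g, h]]

i, j = np.unravel_index(np.argmax(sample[:, :, 0]), (2, 2))
print(i, j)                              # 1 1
print(sample[i, j, 1], sample[i, j, 2])  # 40.0 80.0, i.e. d and h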

Here is the code I wrote to implement the layer's logic:

import tensorflow as tf
from keras.layers import Layer


class CustomGlobalMaxPooling(Layer):

    def __init__(self, **kwargs):
        """Layer computing the argmax in the first channel of each sample in the input
        tensor and returning the values associated to the second and third channel.
        """

        super(CustomGlobalMaxPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 4, "Only 4 dimensional input supported (B, H, W, D)"
        assert input_shape[-1] == 1, "Only supports one channel. Given shape: {}".format(input_shape)
        super(CustomGlobalMaxPooling, self).build(input_shape)

    def call(self, x, **kwargs):
        batch_size = tf.shape(x)[0]
        # Batch indices, used below to build (batch, position) index pairs.
        r = tf.cast(tf.range(batch_size), tf.int64)

        img_h = tf.shape(x)[1]
        img_w = tf.shape(x)[2]

        ## HERE WE ADD THE NEW CHANNELS ON EACH SAMPLE (code omitted) ##
        x = add_channels(x)

        # Flatten each channel to shape (B, H*W).
        flattened_values = tf.reshape(x[:, :, :, 0], (batch_size, img_h * img_w))
        flattened_1 = tf.reshape(x[:, :, :, 1], (batch_size, img_h * img_w))
        flattened_2 = tf.reshape(x[:, :, :, 2], (batch_size, img_h * img_w))

        # Per-sample argmax over the first channel, as (batch, position) pairs.
        argmax = tf.argmax(flattened_values, axis=1)
        argmax = tf.transpose(tf.stack([r, argmax]), [1, 0])

        # Select the corresponding elements of the second and third channels.
        max_1 = tf.gather_nd(flattened_1, argmax)
        max_2 = tf.gather_nd(flattened_2, argmax)

        max_1 = tf.expand_dims(max_1, -1)
        max_2 = tf.expand_dims(max_2, -1)

        # Stack the two values into the final output of shape (B, 2, 1).
        return tf.transpose(tf.stack((max_1, max_2)), [1, 0, 2])

    def compute_output_shape(self, input_shape):
        batch_size = input_shape[0]
        return (batch_size, 2, 1)
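
The add_channels helper is omitted above. For a self-contained test, here is a hypothetical stand-in consistent with the description (two constant H x W maps, simple coordinate grids in this sketch); it is not the omitted code itself:

import numpy as np
import tensorflow as tf

def add_channels(x):
    # Hypothetical stand-in for the omitted helper: appends two constant
    # H x W maps (row and column coordinate grids) as channels 1 and 2.
    h, w = x.get_shape().as_list()[1:3]
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    c1 = tf.constant(rows[None, :, :, None], dtype=x.dtype)  # (1, H, W, 1)
    c2 = tf.constant(cols[None, :, :, None], dtype=x.dtype)  # (1, H, W, 1)
    multiples = tf.stack([tf.shape(x)[0], 1, 1, 1])
    return tf.concat([x, tf.tile(c1, multiples), tf.tile(c2, multiples)], axis=-1)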

My guess is that, by computing an argmax over one channel and using the resulting indices on the other channels, the operation is no longer differentiable, so no gradient is propagated: tf.argmax has no gradient, and tf.gather_nd only propagates gradients into the gathered values (here the constant channels), never through the integer indices, so nothing flows back to the layer's input.
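
A minimal TF 1.x snippet illustrates both points (the placeholder shapes are arbitrary):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(2, 4))   # plays the role of the first channel
v = tf.placeholder(tf.float32, shape=(2, 4))   # plays the role of a constant channel

idx = tf.argmax(x, axis=1)                                    # int64, no gradient
pairs = tf.stack([tf.range(2, dtype=tf.int64), idx], axis=1)  # (batch, position)
sel = tf.gather_nd(v, pairs)

print(tf.gradients(sel, v))  # [<tf.Tensor ...>]: gradients reach the gathered values
print(tf.gradients(sel, x))  # [None]: nothing flows back through the argmax indices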

Is there a way to work around this problem?
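
For reference, a commonly suggested workaround, assuming an approximation of the hard argmax is acceptable, is a soft-argmax: softmax weights computed over the flattened first channel, followed by a weighted sum of the other channels. A rough sketch:

import tensorflow as tf

def soft_select(values, channel, temperature=100.0):
    """Differentiable stand-in for argmax + gather_nd: softmax weights over
    the flattened positions, then a weighted sum of the other channel.
    values, channel: tensors of shape (B, H*W).
    temperature: the higher, the closer the weights get to a hard argmax.
    """
    weights = tf.nn.softmax(values * temperature, axis=1)  # (B, H*W)
    return tf.reduce_sum(weights * channel, axis=1)        # (B,)

Here max_1 = soft_select(flattened_values, flattened_1) and max_2 = soft_select(flattened_values, flattened_2) would replace the two gather_nd calls, and the softmax path lets gradients flow back into the first channel.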

0 Answers:

There are no answers.