After implementing a custom layer in Keras and using it in my network, I noticed that all the preceding layers stopped updating during training. Indeed, looking at the TensorBoard distributions of the weights and gradients computed for my network, the gradients of every layer before the custom layer are zero.
The layer I am trying to implement does the following:
Goal:
Given a tensor of size B x H x W x D with D = 1, add two new channels of size H x W to each sample in the batch (the values of these channels are given as constant inputs, so they are neither learned nor dynamic). On the resulting tensor of shape B x H x W x 3, compute the argmax over the first channel of each sample (i.e. over the 2D image) and use the resulting indices to index into the second and third channels.
Shapes:
Given B x H x W x D = (B x 8 x 8 x 1), the layer outputs a tensor of shape B x 2 x 1, with an intermediate tensor of shape B x H x W x 3.
Numeric example:
Suppose B = 1, H = W = 2, and let the input tensor have the following values:
input[0, :, :, 0] = [[0.1, 0.2],
                     [0.3, 0.4]]
After adding the new channels:
input[0, :, :, 0] = [[0.1, 0.2],
                     [0.3, 0.4]]
input[0, :, :, 1] = [[a, b],
                     [c, d]]
input[0, :, :, 2] = [[e, f],
                     [g, h]]
The layer computes the argmax over the first channel and returns the corresponding elements of the second and third channels:
indices = argmax(input[0, :, :, 0]) = [1, 1]
input[0, indices[0], indices[1], 1] = [d]
input[0, indices[0], indices[1], 2] = [h]
Layer(input) = [
    [[d], [h]]
]
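For reference, the example above can be checked with a few lines of standalone TensorFlow, using the same flatten/argmax/gather_nd steps as the layer below (the constants a..h are placeholders in the example, so arbitrary numbers 10..80 are used here):

```python
import tensorflow as tf

# Input of shape (B, H, W, 3) = (1, 2, 2, 3): channel 0 holds the values
# used for the argmax, channels 1 and 2 hold a..d = 10..40 and e..h = 50..80.
x = tf.constant([[[[0.1, 10., 50.],
                   [0.2, 20., 60.]],
                  [[0.3, 30., 70.],
                   [0.4, 40., 80.]]]])

b, h, w = 1, 2, 2
flat_values = tf.reshape(x[:, :, :, 0], (b, h * w))  # channel used for argmax
flat_1 = tf.reshape(x[:, :, :, 1], (b, h * w))
flat_2 = tf.reshape(x[:, :, :, 2], (b, h * w))

argmax = tf.argmax(flat_values, axis=1)              # flat index of the max (here: 3)
idx = tf.stack([tf.range(b, dtype=tf.int64), argmax], axis=1)

max_1 = tf.gather_nd(flat_1, idx)  # value of channel 1 at the argmax -> d = 40
max_2 = tf.gather_nd(flat_2, idx)  # value of channel 2 at the argmax -> h = 80
```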
Here is the code I wrote to implement the layer's logic:
import tensorflow as tf
from keras.layers import Layer


class CustomGlobalMaxPooling(Layer):
    def __init__(self, **kwargs):
        """Layer computing the argmax over the first channel of each sample in
        the input tensor and returning the values of the second and third
        channels at that position.
        """
        super(CustomGlobalMaxPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 4, "Only 4-dimensional input supported (B, H, W, D)"
        assert input_shape[-1] == 1, "Only a single channel is supported. Given shape: {}".format(input_shape)
        super(CustomGlobalMaxPooling, self).build(input_shape)

    def call(self, x, **kwargs):
        batch_size = tf.shape(x)[0]
        r = tf.cast(tf.range(batch_size), tf.int64)
        img_h = tf.shape(x)[1]
        img_w = tf.shape(x)[2]

        ## HERE WE ADD THE NEW CHANNELS ON EACH SAMPLE (code omitted) ##
        x = add_channels(x)

        # Flatten each channel to shape (B, H * W)
        flattened_values = tf.reshape(x[:, :, :, 0], (batch_size, img_h * img_w))
        flattened_1 = tf.reshape(x[:, :, :, 1], (batch_size, img_h * img_w))
        flattened_2 = tf.reshape(x[:, :, :, 2], (batch_size, img_h * img_w))

        # Flat argmax over the first channel, paired with the batch index
        argmax = tf.argmax(flattened_values, axis=1)
        argmax = tf.transpose(tf.stack([r, argmax]), [1, 0])

        # Gather the second and third channels at the argmax position
        max_1 = tf.gather_nd(flattened_1, argmax)
        max_2 = tf.gather_nd(flattened_2, argmax)
        max_1 = tf.expand_dims(max_1, -1)
        max_2 = tf.expand_dims(max_2, -1)
        return tf.transpose(tf.stack((max_1, max_2)), [1, 0, 2])

    def compute_output_shape(self, input_shape):
        batch_size = input_shape[0]
        return (batch_size, 2, 1)
My guess is that by taking the argmax over one channel and applying the resulting indices to the other channels, my network is no longer differentiable, so no gradients are propagated.
Is there a way to work around this?
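This guess can be confirmed with a minimal, standalone GradientTape check (unrelated to the layer code above): the gradient of a value selected via argmax is None with respect to the tensor the argmax was taken over.

```python
import tensorflow as tf

scores = tf.Variable([0.1, 0.2, 0.4, 0.3])   # channel used for the argmax
values = tf.Variable([10., 20., 30., 40.])   # channel indexed by the result

with tf.GradientTape() as tape:
    idx = tf.argmax(scores)          # integer output: not differentiable
    picked = tf.gather(values, idx)  # 30.0, the value at the argmax

g_scores, g_values = tape.gradient(picked, [scores, values])
# g_scores is None: no gradient flows back through the argmax path.
# g_values receives a gradient only at the picked element.
```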