I am currently working with CycleGAN, using simontomaskarlsson's GitHub repository as a baseline. My problem arises after training has finished, when I want to use the saved model to generate new samples: the architecture of the loaded model differs from that of the initialized generator. The direct link to the saveModel function is here.
When I initialize the generator that translates domain A to domain B, the summary looks as follows (line in github). This is as expected, since my input image is (140, 140, 1) and my output image is (140, 140, 1):
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_5 (InputLayer) (None, 140, 140, 1) 0
__________________________________________________________________________________________________
reflection_padding2d_1 (Reflect (None, 146, 146, 1) 0 input_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 140, 140, 32) 1600 reflection_padding2d_1[0][0]
__________________________________________________________________________________________________
instance_normalization_5 (Insta (None, 140, 140, 32) 64 conv2d_9[0][0]
__________________________________________________________________________________________________
...
__________________________________________________________________________________________________
activation_12 (Activation) (None, 140, 140, 32) 0 instance_normalization_23[0][0]
__________________________________________________________________________________________________
reflection_padding2d_16 (Reflec (None, 146, 146, 32) 0 activation_12[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 140, 140, 1) 1569 reflection_padding2d_16[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 140, 140, 1) 0 conv2d_26[0][0]
==================================================================================================
Total params: 2,258,177
Trainable params: 2,258,177
Non-trainable params: 0
After training has finished, I want to load the saved model to generate new samples (translations from domain A to domain B). Whether the model translates the images successfully does not matter here. I load the model with the following code:
# load json and create model
json_file = open('G_A2B_model.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json, custom_objects={'ReflectionPadding2D': ReflectionPadding2D, 'InstanceNormalization': InstanceNormalization})
or with the following, which gives the same result:
loaded_model = load_model('G_A2B_model.h5', custom_objects={'ReflectionPadding2D': ReflectionPadding2D, 'InstanceNormalization': InstanceNormalization})
where ReflectionPadding2D is defined as below (note that I load the model in a separate file from the one used to train the CycleGAN):
# reflection padding taken from
# https://github.com/fastai/courses/blob/master/deeplearning2/neural-style.ipynb
class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
Now that my model is loaded, I want to translate an image from domain A to domain B. Here I expected the output shape to be (140, 140, 1), but surprisingly it is (132, 132, 1). I checked the architecture summary of G_A2B_model, which clearly shows that the output shape is (132, 132, 1):
Model: "G_A2B_model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_5 (InputLayer) (None, 140, 140, 1) 0
__________________________________________________________________________________________________
reflection_padding2d_1 (Reflect (None, 142, 142, 1) 0 input_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 136, 136, 32) 1600 reflection_padding2d_1[0][0]
__________________________________________________________________________________________________
instance_normalization_5 (Insta (None, 136, 136, 32) 64 conv2d_9[0][0]
__________________________________________________________________________________________________
...
__________________________________________________________________________________________________
instance_normalization_23 (Inst (None, 136, 136, 32) 64 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 136, 136, 32) 0 instance_normalization_23[0][0]
__________________________________________________________________________________________________
reflection_padding2d_16 (Reflec (None, 138, 138, 32) 0 activation_12[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 132, 132, 1) 1569 reflection_padding2d_16[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 132, 132, 1) 0 conv2d_26[0][0]
==================================================================================================
Total params: 2,258,177
Trainable params: 2,258,177
Non-trainable params: 0
What I don't understand is why the output shape is (132, 132, 1). I can see that something goes wrong in ReflectionPadding2D, where the first padding layer of the initialized generator has output shape (146, 146, 1) while the same layer of the saved generator has (142, 142, 1). But I have no idea why this happens, since in theory they should be the same size.
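The two summaries are consistent with the padding falling back to a default of (1, 1) when the model is reloaded. A quick sanity check of the arithmetic (assuming stride-1 'valid' convolutions, and a 7×7 kernel for conv2d_9 as implied by its parameter count, 7*7*1*32 + 32 = 1600):

```python
def pad(size, padding):
    """Output size of ReflectionPadding2D along one spatial axis."""
    return size + 2 * padding

def conv_valid(size, kernel):
    """Output size of a stride-1 'valid' convolution along one axis."""
    return size - kernel + 1

# Training-time generator: padding=(3, 3) passed explicitly.
assert pad(140, 3) == 146          # reflection_padding2d_1
assert conv_valid(146, 7) == 140   # conv2d_9

# Reloaded generator: padding reverts to the default (1, 1).
assert pad(140, 1) == 142          # reflection_padding2d_1 after reload
assert conv_valid(142, 7) == 136   # conv2d_9 after reload
```

Each reflection-padding layer losing 4 pixels of padding (2 per side per axis) accumulates over the two padded 7×7 convolutions to the observed 140 → 132 shrinkage.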
Answer 0 (score: 1)
When you save the architecture with model.to_json, the method get_config is called so that the layer attributes are saved as well. Since your custom class is missing that method, the default value for padding is used when you call model_from_json. Using the following code for ReflectionPadding2D will solve your problem; just run the training step again and reload the model.
class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    # This is the relevant method that should be added
    def get_config(self):
        config = {
            'padding': self.padding
        }
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
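You can see the effect of get_config without a full training run by round-tripping the layer's config through JSON, which is essentially what model_from_json does. The sketch below uses a minimal stand-in for keras.layers.Layer (it only exercises the config plumbing, so TensorFlow itself is not needed):

```python
import json

# Minimal stand-in for keras.layers.Layer: just enough config plumbing
# to show why get_config matters for (de)serialization.
class Layer:
    def __init__(self, **kwargs):
        self.name = kwargs.get('name', self.__class__.__name__.lower())

    def get_config(self):
        return {'name': self.name}

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def get_config(self):
        config = {'padding': self.padding}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# Serialize a layer built with non-default padding, then rebuild it the
# way model_from_json would: from the stored config alone.
original = ReflectionPadding2D(padding=(3, 3))
stored = json.loads(json.dumps(original.get_config()))
restored = ReflectionPadding2D(**stored)

assert restored.padding == (3, 3)  # without get_config this would be (1, 1)
```

Without the get_config override, the stored config contains no 'padding' entry, so the rebuilt layer silently uses the default (1, 1), which is exactly the shape discrepancy seen in the question.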