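I am working with https://github.com/RaphaelMeudec/deblur-gan/ to improve their DeblurGAN. My goal is to convert the H5 model produced by DeblurGAN training to TFLite format.
My H5 model contains a custom layer named ReflectionPadding2D (code given below). I use the following Python code for the conversion: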
g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)))
converter = tf.lite.TFLiteConverter.from_keras_model_file(model, custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)
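As you can see, I pass custom_objects. ReflectionPadding2D is simply the class (not an instance) imported via from deblurgan.layer_utils import ReflectionPadding2D.
Because my model contains the custom layer ReflectionPadding2D, the code above fails with the following error:
ValueError: Unknown layer: ReflectionPadding2D
Perhaps I need to add something when saving the H5 file so that it can be converted to TFLite? This is the code I use to save the H5 model: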
g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
Code of ReflectionPadding2D (see the call method):
def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """
    Pad the 2nd and 3rd dimensions of a 4D tensor.
    :param x: Input tensor
    :param padding: Shape of padding to use
    :param data_format: Tensorflow vs Theano convention ('channels_last', 'channels_first')
    :return: Tensorflow tensor
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2
    if data_format is None:
        data_format = image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))
    if data_format == 'channels_first':
        pattern = [[0, 0],
                   [0, 0],
                   list(padding[0]),
                   list(padding[1])]
    else:
        pattern = [[0, 0],
                   list(padding[0]), list(padding[1]),
                   [0, 0]]
    return tf.pad(x, pattern, "REFLECT")
class ReflectionPadding2D(Layer):
    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):
            self.padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, '__len__'):
            if len(padding) != 2:
                raise ValueError('`padding` should have two elements. '
                                 'Found: ' + str(padding))
            height_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                        '1st entry of padding')
            width_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                       '2nd entry of padding')
            self.padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            if input_shape[2] is not None:
                rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[3] is not None:
                cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    input_shape[1],
                    rows,
                    cols)
        elif self.data_format == 'channels_last':
            if input_shape[1] is not None:
                rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[2] is not None:
                cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    rows,
                    cols,
                    input_shape[3])

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs,
                                             padding=self.padding,
                                             data_format=self.data_format)

    def get_config(self):
        config = {'padding': self.padding,
                  'data_format': self.data_format}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
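For readers unfamiliar with what tf.pad(x, pattern, "REFLECT") computes in call, here is a small pure-Python sketch of reflection padding on a single 2D grid (one channel, no batch dimension). The helper names reflect_pad_1d and reflect_pad_2d are hypothetical and are not part of deblur-gan or TensorFlow. The sketch only mimics two properties: REFLECT mirrors values around the border without repeating the edge value, and each padded dimension grows by the sum of its two padding amounts, which is the arithmetic in compute_output_shape.

```python
def reflect_pad_1d(values, before, after):
    """Reflect-pad a list without repeating the edge element.

    Both padding amounts must be smaller than len(values), the same
    restriction that REFLECT padding imposes.
    """
    n = len(values)
    if before >= n or after >= n:
        raise ValueError("padding must be smaller than the dimension size")
    # Mirror around the first element: values[before], ..., values[1]
    left = [values[i] for i in range(before, 0, -1)]
    # Mirror around the last element: values[n-2], values[n-3], ...
    right = [values[n - 2 - i] for i in range(after)]
    return left + list(values) + right


def reflect_pad_2d(grid, padding=((1, 1), (1, 1))):
    """Pad rows (height) first, then every row's columns (width)."""
    (top, bottom), (left, right) = padding
    rows = reflect_pad_1d(grid, top, bottom)
    return [reflect_pad_1d(row, left, right) for row in rows]


grid = [[1, 2, 3],
        [4, 5, 6]]
padded = reflect_pad_2d(grid, padding=((1, 1), (2, 2)))
for row in padded:
    print(row)

# Output shape follows the same arithmetic as compute_output_shape.
print(len(padded), len(padded[0]))  # 2 + 1 + 1 rows, 3 + 2 + 2 columns
```

The padding argument uses the same ((top_pad, bottom_pad), (left_pad, right_pad)) layout that the layer stores in self.padding.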
What should I change to make the conversion to TFLite possible?
Answer 0 (score: 0)
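This code works. The custom_objects mapping is passed to tf.keras.models.load_model, so Keras can resolve ReflectionPadding2D while deserializing the H5 file, and the loaded model object is then given to tf.lite.TFLiteConverter.from_keras_model: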
g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)
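Why the placement of custom_objects matters can be illustrated with a toy model of name-based layer lookup. A saved model records each layer's class name and config rather than the Python class itself, so the loader has to map the name back to a class. The sketch below is a deliberately simplified stand-in and not Keras's actual implementation; KNOWN_LAYERS, layer_from_config and the two stub classes are hypothetical names used only for illustration.

```python
class Dense:
    """Stub standing in for a built-in layer."""
    def __init__(self, units):
        self.units = units


class ReflectionPadding2D:
    """Stub standing in for the user-defined layer."""
    def __init__(self, padding=(1, 1)):
        self.padding = padding


# Names the toy loader can resolve without help (built-in layers only).
KNOWN_LAYERS = {"Dense": Dense}


def layer_from_config(saved, custom_objects=None):
    """Rebuild a layer from its saved class name and config."""
    lookup = dict(KNOWN_LAYERS)
    lookup.update(custom_objects or {})
    cls = lookup.get(saved["class_name"])
    if cls is None:
        raise ValueError("Unknown layer: " + saved["class_name"])
    return cls(**saved["config"])


saved_layer = {"class_name": "ReflectionPadding2D",
               "config": {"padding": (3, 3)}}

# Without the mapping, the name cannot be resolved.
try:
    layer_from_config(saved_layer)
except ValueError as err:
    print(err)

# With the mapping, the class is found and rebuilt from its config.
layer = layer_from_config(
    saved_layer, custom_objects={"ReflectionPadding2D": ReflectionPadding2D})
print(type(layer).__name__, layer.padding)
```

In this toy version, as in the code above, the mapping has to be available at the point where the saved description is turned back into objects, which is the load step rather than the conversion step.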