TFLite custom objects: ValueError: Unknown layer: ReflectionPadding2D

Time: 2019-10-29 10:38:17

Tags: python tensorflow keras keras-layer tf.keras

I am working on https://github.com/RaphaelMeudec/deblur-gan/ in order to improve their DeblurGAN. My goal is to convert the H5 model trained by DeblurGAN to the TFLite format.

My H5 model contains a custom layer that I defined, named ReflectionPadding2D (code given below). I use the following Python commands for the conversion:

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)))
converter = tf.lite.TFLiteConverter.from_keras_model_file(model, custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)

As you can see, I pass custom_objects. ReflectionPadding2D is simply the class (not an instance) imported with from deblurgan.layer_utils import ReflectionPadding2D.

Because my model contains my custom layer ReflectionPadding2D, the commands above output the following error:

ValueError: Unknown layer: ReflectionPadding2D

Code that saves my H5 model

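Maybe I should add a line so that the saved H5 file includes something that allows it to be converted to TFLite? Here is the code I use to save the H5 model: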

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)

Code of ReflectionPadding2D (see the call method)

def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """
    Pad the 2nd and 3rd dimensions of a 4D tensor.

    :param x: Input tensor
    :param padding: Shape of padding to use
    :param data_format: Tensorflow vs Theano convention ('channels_last', 'channels_first')
    :return: Tensorflow tensor
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2
    if data_format is None:
        data_format = image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))

    if data_format == 'channels_first':
        pattern = [[0, 0],
                   [0, 0],
                   list(padding[0]),
                   list(padding[1])]
    else:
        pattern = [[0, 0],
                   list(padding[0]), list(padding[1]),
                   [0, 0]]
    return tf.pad(x, pattern, "REFLECT")


class ReflectionPadding2D(Layer):

    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):
            self.padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, '__len__'):
            if len(padding) != 2:
                raise ValueError('`padding` should have two elements. '
                                 'Found: ' + str(padding))
            height_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                        '1st entry of padding')
            width_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                       '2nd entry of padding')
            self.padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            if input_shape[2] is not None:
                rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[3] is not None:
                cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    input_shape[1],
                    rows,
                    cols)
        elif self.data_format == 'channels_last':
            if input_shape[1] is not None:
                rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[2] is not None:
                cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    rows,
                    cols,
                    input_shape[3])

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs,
                                             padding=self.padding,
                                             data_format=self.data_format)

    def get_config(self):
        config = {'padding': self.padding,
                  'data_format': self.data_format}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
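
For readers unfamiliar with the "REFLECT" mode used in spatial_reflection_2d_padding, here is a minimal plain-Python sketch of the padding it requests for a single 2D grid. The helper names reflect_pad_1d and reflect_pad_2d are hypothetical and are not part of deblur-gan or TensorFlow; the sketch only illustrates the mirroring rule (the edge value itself is not repeated) and the resulting size, it is not how tf.pad is implemented.

```python
def reflect_pad_1d(seq, before, after):
    """Mirror a sequence about its first and last element (edge not repeated)."""
    n = len(seq)
    if before >= n or after >= n:
        # Same restriction as REFLECT mode: padding must be smaller than the size.
        raise ValueError("reflection padding must be smaller than the dimension")
    head = [seq[i] for i in range(before, 0, -1)]   # seq[before], ..., seq[1]
    tail = [seq[n - 2 - i] for i in range(after)]   # seq[n-2], seq[n-3], ...
    return head + list(seq) + tail


def reflect_pad_2d(grid, padding=((1, 1), (1, 1))):
    """Pad rows and columns of a list-of-lists grid by reflection."""
    (top, bottom), (left, right) = padding
    # Pad every row along the width first, then mirror whole rows along the height.
    padded_rows = [reflect_pad_1d(row, left, right) for row in grid]
    padded = reflect_pad_1d(padded_rows, top, bottom)
    # Copy each row so mirrored rows do not share the same list object.
    return [list(row) for row in padded]


if __name__ == "__main__":
    for row in reflect_pad_2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]):
        print(row)
```

The output grid has top + bottom extra rows and left + right extra columns, which is the same arithmetic that compute_output_shape performs on the height and width entries of the input shape.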

Final question

What should I change to allow the conversion to the TFLite format?

1 answer:

Answer 0: (score: 0)

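This code works. custom_objects is passed to tf.keras.models.load_model, which is where the H5 file is deserialized, and the loaded model object is then given to TFLiteConverter.from_keras_model: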

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)

model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)