TFLite converter with a custom Keras layer

Time: 2019-09-10 07:10:19

Tags: tensorflow keras-layer tensorflow-lite tf.keras

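I get a TypeError when trying to convert a Keras .h5 file to TFLite. The new layer is a Gaussian kernel (a radial basis function layer).
To be able to save and load the Keras model, I also defined a get_config() method in the custom layer, and with it the model saves and loads correctly.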
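Saving and loading through get_config() works because Keras stores the returned dict in the .h5 file as JSON and, by default, rebuilds the layer by calling the class with that dict as keyword arguments (`cls(**config)`). That only round-trips if the config keys match the constructor arguments. A minimal sketch of the round trip, assuming the centers are a NumPy array; `ConfigurableRBF` and `to_json_type` are hypothetical stand-ins, not the real Keras classes (tf.keras converts NumPy values for JSON internally, and `to_json_type` only imitates that):

```python
import json

import numpy as np


class ConfigurableRBF:
    """Hypothetical stand-in for RBFLayer: only the constructor/config part."""

    def __init__(self, output_dim, centers=None, tol=1e-6, gamma=0):
        self.output_dim = output_dim
        self.centers_ = centers
        self.tol_ = tol
        self.gamma_ = gamma

    def get_config(self):
        # Keys match the constructor arguments, so cls(**config) works.
        return {
            "output_dim": self.output_dim,
            "centers": self.centers_,
            "gamma": self.gamma_,
            "tol": self.tol_,
        }

    @classmethod
    def from_config(cls, config):
        # Same default strategy as a Keras Layer: pass the dict back as kwargs.
        return cls(**config)


def to_json_type(obj):
    # NumPy values are not JSON-serializable; store them as plain Python types.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Not JSON serializable: {type(obj)!r}")


original = ConfigurableRBF(
    output_dim=2, centers=np.array([[0.0, 1.0], [2.0, 3.0]]), gamma=0.5
)
payload = json.dumps(original.get_config(), default=to_json_type)
restored = ConfigurableRBF.from_config(json.loads(payload))
print(restored.centers_)  # [[0.0, 1.0], [2.0, 3.0]]
```

The restored centers are a nested list rather than an ndarray, so whatever consumes them in build() (here `K.variable`) has to accept either form.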

# Layer and K (the Keras backend module) are imported elsewhere in the
# original script; the question does not show which package they come from.
class RBFLayer(Layer):
    def __init__(self, output_dim, centers=None, tol=1e-6, gamma=0, **kwargs):
        super(RBFLayer, self).__init__(**kwargs)
        self.centers_ = centers      # (num_centers, input_dim) array of RBF centers
        self.output_dim = output_dim
        self.gamma_ = gamma          # kernel width parameter
        self.tol_ = tol              # activations not above this value are set to 0

    def build(self, input_shape):
        self.mu = K.variable(self.centers_, name='centers')
        self.gamma = K.variable(self.gamma_, name='gamma')
        self.tol = K.constant(self.tol_, name='tol')
        super(RBFLayer, self).build(input_shape)

    def call(self, inputs):  # radial (Gaussian) kernel
        a, b = self.mu.shape  # a centers, each of dimension b
        # Pair every input row with every center: diff has shape (batch, a, b).
        diff = K.reshape(K.tile(inputs, (1, a)) - K.reshape(self.mu, (1, -1)), (-1, a, b))
        l2 = K.sum(K.pow(diff, 2), axis=-1)  # squared distance to each center
        res = K.exp(-1 * self.gamma * l2)
        mask = K.greater(res, self.tol)
        # Zero out activations that do not exceed tol.
        return K.switch(mask, res, K.zeros_like(res))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'centers': self.centers_,
            'gamma': self.gamma_,
            'tol': self.tol_,
        }
        base_config = super(RBFLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
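When comparing the converted TFLite model's output against Keras, a NumPy reference of what call() computes can help. The following is a sketch with hypothetical helper names, not part of the question's code: `rbf_forward_tiled` mirrors the tile/reshape approach of call(), and `rbf_forward_broadcast` computes the same result with broadcasting, which avoids building the tiled tensor. It assumes `mu` has shape (number of centers, input dimension) and `gamma` and `tol` are scalars.

```python
import numpy as np


def rbf_forward_tiled(inputs, mu, gamma, tol):
    """NumPy mirror of RBFLayer.call, using the same tile/reshape trick."""
    a, b = mu.shape  # a centers, each of dimension b
    # Repeat each input row a times, subtract the flattened centers,
    # then view the result as (batch, a, b): diff[i, j] = inputs[i] - mu[j].
    diff = (np.tile(inputs, (1, a)) - mu.reshape(1, -1)).reshape(-1, a, b)
    l2 = np.sum(diff ** 2, axis=-1)  # squared distance to each center
    res = np.exp(-1 * gamma * l2)    # Gaussian kernel
    # Equivalent of K.switch(K.greater(res, tol), res, zeros).
    return np.where(res > tol, res, np.zeros_like(res))


def rbf_forward_broadcast(inputs, mu, gamma, tol):
    """Same result via broadcasting: (batch, 1, b) - (1, a, b)."""
    diff = inputs[:, None, :] - mu[None, :, :]
    res = np.exp(-gamma * np.sum(diff ** 2, axis=-1))
    return np.where(res > tol, res, 0.0)


x = np.array([[0.0, 0.0], [1.0, 1.0]])
mu = np.array([[0.0, 0.0], [3.0, 4.0]])
out = rbf_forward_tiled(x, mu, gamma=1.0, tol=1e-6)
# out[0, 1] is zeroed: exp(-25) falls below tol, while exp(-13) in out[1, 1] does not.
print(out)
```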

Now I want to save the model as TFLite. I used TFLiteConverter.from_keras_model_file, passing "custom_objects" as well.

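def save_tflite(self, base_name):
    file = base_name + '.h5'
    # from_keras_model_file loads the .h5 file with tf.keras; custom_objects
    # maps the class name stored in the file to the RBFLayer class.
    converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects={'RBFLayer':RBFLayer})
    tflite_model = converter.convert()
    with open(base_name + '.tflite', 'wb') as f:
        f.write(tflite_model)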

I expect to get a TFLite model file that includes the K.variables used when training the whole model (centers, tol and gamma).

When converting, I get the following error message:

airgorbn.save_tflite( base_name)
Traceback (most recent call last):

  File "<ipython-input-7-cdaa1ec46233>", line 1, in <module>
    airgorbn.save_tflite( base_name)

  File "C:/Users/AIRFI/Hospital/keras_RadialBasis.py", line 158, in save_tflite
    converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects={'RBFLayer':RBFLayer})

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\lite\python\lite.py", line 747, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file, custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\save.py", line 146, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\hdf5_format.py", line 212, in load_model_from_hdf5
    custom_objects=custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 89, in deserialize
    printable_module_name='layer')

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 192, in deserialize_keras_object
    list(custom_objects.items())))

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 353, in from_config
    model.add(layer)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\training\tracking\base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 154, in add
    'Found: ' + str(layer))

TypeError: The added layer must be an instance of class Layer. Found: <__main__.RBFLayer object at 0x0000017D3A75AC50>
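The traceback shows the TypeError being raised while `load_model` rebuilds the Sequential model from the .h5 config (`Sequential.from_config` calling `model.add`), before any TFLite conversion runs. One possible cause, stated here as an assumption, is that `Layer` and `K` were imported from the standalone `keras` package while `from_keras_model_file` loads the file with `tf.keras`: two classes that are both called `Layer` but live in different packages are unrelated as far as `isinstance` is concerned. This pure-Python sketch shows that failure mode with hypothetical stand-in classes (none of them are real Keras classes):

```python
class TfKerasLayer:
    """Hypothetical stand-in for tf.keras.layers.Layer."""


class StandaloneKerasLayer:
    """Hypothetical stand-in for the Layer class of the standalone `keras` package."""


def sequential_add(layer):
    # Imitates the kind of check behind the traceback: the model only accepts
    # instances of its own Layer base class, whatever the subclass is named.
    if not isinstance(layer, TfKerasLayer):
        raise TypeError(
            "The added layer must be an instance of class Layer. Found: " + str(layer)
        )
    return layer


class RBFFromStandalone(StandaloneKerasLayer):
    pass


class RBFFromTfKeras(TfKerasLayer):
    pass


try:
    sequential_add(RBFFromStandalone())
    rejected = False
except TypeError:
    rejected = True  # different base class, so the isinstance check fails

accepted = sequential_add(RBFFromTfKeras())
```

If this is the cause, subclassing `tensorflow.keras.layers.Layer` and using `tensorflow.keras.backend` as `K` would be the first thing to try.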

1 answer:

Answer 0 (score: 0)

You need to define the layer as a custom op.

See https://www.tensorflow.org/lite/guide/ops_custom