尝试将keras .h5文件转换为tflite时出现TypeError。
新层是一个高斯核层(即径向基函数层,RBF layer)。
为了能够保存和加载keras模型,我还在自定义层中定义了get_config()方法。这样我就可以正确保存和加载模型。
class RBFLayer(Layer):
    """Gaussian radial-basis-function (RBF) layer.

    For each input row ``x`` and each center ``c_j`` the layer computes
    ``exp(-gamma * ||x - c_j||^2)``; activations that are not greater than
    ``tol`` are replaced by zero.

    NOTE(review): the reported ``TypeError: The added layer must be an
    instance of class Layer`` most likely comes from mixing packages: if
    ``Layer``/``K`` are imported from standalone ``keras`` while the model is
    reloaded by ``tf.lite.TFLiteConverter`` (which uses ``tf.keras``), the
    ``tf.keras`` Sequential rejects this class. Import ``Layer`` and the
    backend from ``tensorflow.keras`` so both sides agree — confirm against
    the file's imports.

    Args:
        output_dim: Number of RBF units. Reported by ``compute_output_shape``
            and used as the number of centers when ``centers`` is None.
        centers: Initial centers, array-like of shape
            ``(n_centers, input_dim)``. If None, centers are initialised
            randomly with shape ``(output_dim, input_dim)``.
        tol: Activations ``<= tol`` are set to zero.
        gamma: Initial value of the trainable width parameter.
    """

    def __init__(self, output_dim, centers=None, tol=1e-6, gamma=0, **kwargs):
        super(RBFLayer, self).__init__(**kwargs)
        self.centers_ = centers
        self.output_dim = output_dim
        self.gamma_ = gamma
        self.tol_ = tol

    def build(self, input_shape):
        """Create the trainable ``centers`` and ``gamma`` weights."""
        # Local import: the module-level imports are not visible from here.
        import numpy as np

        def _constant_init(value):
            # Initializer callable accepted by both keras and tf.keras
            # ``add_weight`` (older keras calls it without ``dtype``).
            def init(shape, dtype=None, **kwargs):
                return K.constant(value, dtype=dtype)
            return init

        if self.centers_ is None:
            # Previously K.variable(None) crashed; fall back to random centers.
            centers_shape = (self.output_dim, int(input_shape[-1]))
            centers_init = 'random_uniform'
        else:
            init_centers = np.asarray(self.centers_, dtype=K.floatx())
            centers_shape = init_centers.shape
            centers_init = _constant_init(init_centers)

        # add_weight (instead of a bare K.variable) registers the tensors as
        # layer weights, so they are trained and stored with the model weights.
        self.mu = self.add_weight(name='centers',
                                  shape=centers_shape,
                                  initializer=centers_init,
                                  trainable=True)
        self.gamma = self.add_weight(name='gamma',
                                     shape=(),
                                     initializer=_constant_init(self.gamma_),
                                     trainable=True)
        self.tol = K.constant(self.tol_, name='tol')
        super(RBFLayer, self).build(input_shape)

    def call(self, inputs):
        """Return the thresholded Gaussian kernel between inputs and centers."""
        # Broadcast (batch, 1, input_dim) - (1, n_centers, input_dim)
        # -> (batch, n_centers, input_dim); assumes 2-D inputs, as the
        # original tile/reshape formulation did.
        diff = K.expand_dims(inputs, 1) - K.expand_dims(self.mu, 0)
        l2 = K.sum(K.square(diff), axis=-1)
        res = K.exp(-1 * self.gamma * l2)
        mask = K.greater(res, self.tol)
        return K.switch(mask, res, K.zeros_like(res))

    def compute_output_shape(self, input_shape):
        # NOTE(review): the real output width is the number of centers; this
        # is only correct when len(centers) == output_dim.
        return (input_shape[0], self.output_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        centers = self.centers_
        if hasattr(centers, 'tolist'):
            # Store numpy arrays as plain lists so the config is JSON-friendly.
            centers = centers.tolist()
        config = {
            'output_dim': self.output_dim,
            'centers': centers,
            'gamma': self.gamma_,
            'tol': self.tol_,
        }
        base_config = super(RBFLayer, self).get_config()
        base_config.update(config)
        return base_config
现在,我要把模型保存为tflite格式。我使用 TFLiteConverter.from_keras_model_file 从 .h5 文件加载模型,并传入了 custom_objects 参数。
def save_tflite(self, base_name):
    """Convert the saved Keras model ``<base_name>.h5`` into ``<base_name>.tflite``.

    The .h5 file is reloaded with ``RBFLayer`` registered as a custom object
    and converted with ``tf.lite.TFLiteConverter``.

    NOTE(review): ``from_keras_model_file`` is the TF 1.x API; under TF 2.x it
    lives at ``tf.compat.v1.lite.TFLiteConverter.from_keras_model_file``.
    The loader uses ``tf.keras``, so ``RBFLayer`` must subclass the
    ``tf.keras`` ``Layer`` for the reload to succeed.

    Args:
        base_name: Path prefix shared by the input .h5 and output .tflite files.
    """
    model_path = base_name + '.h5'
    converter = tf.lite.TFLiteConverter.from_keras_model_file(
        model_path, custom_objects={'RBFLayer': RBFLayer})
    tflite_model = converter.convert()
    # Context manager guarantees the handle is flushed and closed.
    with open(base_name + ".tflite", "wb") as out_file:
        out_file.write(tflite_model)
我希望得到的tflite模型文件中包含训练整个模型时使用的K.variable(centers、tol 和 gamma)。
转换时,我收到以下错误消息:
airgorbn.save_tflite( base_name)
Traceback (most recent call last):
File "<ipython-input-7-cdaa1ec46233>", line 1, in <module>
airgorbn.save_tflite( base_name)
File "C:/Users/AIRFI/Hospital/keras_RadialBasis.py", line 158, in save_tflite
converter = tf.lite.TFLiteConverter.from_keras_model_file(file, custom_objects={'RBFLayer':RBFLayer})
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\lite\python\lite.py", line 747, in from_keras_model_file
keras_model = _keras.models.load_model(model_file, custom_objects)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\save.py", line 146, in load_model
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\hdf5_format.py", line 212, in load_model_from_hdf5
custom_objects=custom_objects)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\saving\model_config.py", line 55, in model_from_config
return deserialize(config, custom_objects=custom_objects)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 89, in deserialize
printable_module_name='layer')
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 192, in deserialize_keras_object
list(custom_objects.items())))
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 353, in from_config
model.add(layer)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\training\tracking\base.py", line 457, in _method_wrapper
result = method(self, *args, **kwargs)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\sequential.py", line 154, in add
'Found: ' + str(layer))
TypeError: The added layer must be an instance of class Layer. Found: <__main__.RBFLayer object at 0x0000017D3A75AC50>