我在keras中创建了一个自定义层,以便在将其输入ConvLSTM2D层之前重塑CNN的输出
class TemporalReshape(Layer):
def __init__(self,batch_size,num_patches):
super(TemporalReshape,self).__init__()
self.batch_size = batch_size
self.num_patches = num_patches
def call(self,inputs):
nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
return tf.reshape(inputs, nshape)
def get_config(self):
config = super().get_config().copy()
config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
return config
当我尝试使用加载最佳模型时
model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
我得到了错误
TypeError Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
180 if (h5py is not None and (
181 isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182 return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
183
184 filepath = path_to_string(filepath)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
176 model_config = json.loads(model_config.decode('utf-8'))
177 model = model_config_lib.model_from_config(model_config,
--> 178 custom_objects=custom_objects)
179
180 # set weights
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
53 '`Sequential.from_config(config)`?')
54 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 55 return deserialize(config, custom_objects=custom_objects)
56
57
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
173 module_objects=LOCAL.ALL_OBJECTS,
174 custom_objects=custom_objects,
--> 175 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
356 custom_objects=dict(
357 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358 list(custom_objects.items())))
359 with CustomObjectScope(custom_objects):
360 return cls.from_config(cls_config)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
615 """
616 input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617 config, custom_objects)
618 model = cls(inputs=input_tensors, outputs=output_tensors,
619 name=config.get('name'))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1202 # First, we create all layers and enqueue nodes to be processed
1203 for layer_data in config['layers']:
-> 1204 process_layer(layer_data)
1205 # Then we process nodes in order of layer depth.
1206 # Nodes that cannot yet be processed (if the inbound node
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
1184 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
1185
-> 1186 layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1187 created_layers[layer_name] = layer
1188
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
173 module_objects=LOCAL.ALL_OBJECTS,
174 custom_objects=custom_objects,
--> 175 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
358 list(custom_objects.items())))
359 with CustomObjectScope(custom_objects):
--> 360 return cls.from_config(cls_config)
361 else:
362 # Then `cls` may be a function returning a class.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
695 A layer instance.
696 """
--> 697 return cls(**config)
698
699 def compute_output_shape(self, input_shape):
TypeError: __init__() got an unexpected keyword argument 'name'
在构建模型时,我使用了如下的自定义层:
x = TemporalReshape(batch_size = 8, num_patches = 16)(x)
是什么原因导致错误,以及如何在没有此错误的情况下加载模型?
答案 0 :(得分:2)
仅基于错误消息,我建议将/unauthenticated/posts/1
放在**kwargs
中。然后,该对象将接受您未包含的任何其他关键字参数。
__init__
答案 1 :(得分:0)
在将** kwargs插入到init函数之后,我收到错误:“ TypeError: init ()缺少3个必需的位置参数:'batch_size','num_patches'”