TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with a custom layer

Time: 2020-10-13 14:22:05

Tags: python tensorflow keras tensorflow2.0 tf.keras

I created a custom layer in Keras to reshape the output of a CNN before feeding it into a ConvLSTM2D layer:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class TemporalReshape(Layer):
    def __init__(self, batch_size, num_patches):
        super(TemporalReshape, self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        # (batch_size * num_patches, ...) -> (batch_size, num_patches, ...)
        nshape = (self.batch_size, self.num_patches) + inputs.shape[1:]
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'batch_size': self.batch_size, 'num_patches': self.num_patches})
        return config

When I try to load the best model using

model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})


/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    180     if (h5py is not None and (
    181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
    183 
    184     filepath = path_to_string(filepath)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    176     model_config = json.loads(model_config.decode('utf-8'))
    177     model = model_config_lib.model_from_config(model_config,
--> 178                                                custom_objects=custom_objects)
    179 
    180     # set weights

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
     53                     '`Sequential.from_config(config)`?')
     54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 55   return deserialize(config, custom_objects=custom_objects)
     56 
     57 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    356             custom_objects=dict(
    357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
    360         return cls.from_config(cls_config)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
    615     """
    616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617         config, custom_objects)
    618     model = cls(inputs=input_tensors, outputs=output_tensors,
    619                 name=config.get('name'))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1202   # First, we create all layers and enqueue nodes to be processed
   1203   for layer_data in config['layers']:
-> 1204     process_layer(layer_data)
   1205   # Then we process nodes in order of layer depth.
   1206   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
   1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1185 
-> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1187       created_layers[layer_name] = layer
   1188 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
--> 360         return cls.from_config(cls_config)
    361     else:
    362       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    695         A layer instance.
    696     """
--> 697     return cls(**config)
    698 
    699   def compute_output_shape(self, input_shape):

TypeError: __init__() got an unexpected keyword argument 'name'
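The last frame of the traceback points at the mechanism: the base Layer.from_config(config) simply calls cls(**config), and the config built by super().get_config() carries generic keys such as 'name', 'trainable' and 'dtype' next to batch_size and num_patches. The sketch below reproduces that in plain Python so it runs without TensorFlow. MockLayer is a hypothetical, heavily simplified stand-in for tf.keras.layers.Layer, not Keras's actual source.

```python
class MockLayer:
    """Hypothetical, heavily simplified stand-in for tf.keras.layers.Layer."""

    def __init__(self, name=None, trainable=True, dtype='float32'):
        self.name = name or type(self).__name__
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        # The real base get_config() also returns generic keys like these.
        return {'name': self.name, 'trainable': self.trainable, 'dtype': self.dtype}

    @classmethod
    def from_config(cls, config):
        # Mirrors the last traceback frame (base_layer.py): `return cls(**config)`.
        return cls(**config)


class TemporalReshape(MockLayer):
    # Same signature as in the question: no **kwargs, nothing forwarded to super().
    def __init__(self, batch_size, num_patches):
        super().__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({'batch_size': self.batch_size, 'num_patches': self.num_patches})
        return config


config = TemporalReshape(batch_size=8, num_patches=16).get_config()
try:
    TemporalReshape.from_config(config)
except TypeError as err:
    # The message ends with: got an unexpected keyword argument 'name'
    print(err)
```

Because 'name' is the first key in the config, it is the first argument __init__ rejects, which matches the error in the question.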

When building the model, I used the custom layer like this:

x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

What is causing the error, and how can I load the model without it?

2 Answers:

Answer 0 (score: 2)

Based on the error message alone, I would suggest putting **kwargs in __init__. The object will then accept any additional keyword arguments that you did not include.
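A minimal sketch of that suggestion, assuming TensorFlow 2.x: accept **kwargs in __init__ and forward it to super().__init__(**kwargs), so that 'name', 'trainable' and 'dtype' reach the base Layer, which can then restore the saved name and dtype. The round trip through get_config() and from_config() at the end mimics what load_model does for each layer; the 4x4x3 input shape is an arbitrary example, and tuple(inputs.shape[1:]) is used so the shape arithmetic stays plain Python ints.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class TemporalReshape(Layer):
    def __init__(self, batch_size, num_patches, **kwargs):
        # Forward name, trainable, dtype, ... to the base Layer.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        # (batch_size * num_patches, H, W, C) -> (batch_size, num_patches, H, W, C)
        nshape = (self.batch_size, self.num_patches) + tuple(inputs.shape[1:])
        return tf.reshape(inputs, nshape)

    def get_config(self):
        # The base config already contains 'name', 'trainable' and 'dtype'.
        config = super().get_config()
        config.update({'batch_size': self.batch_size, 'num_patches': self.num_patches})
        return config


# Round trip without touching the .h5 file.
layer = TemporalReshape(batch_size=8, num_patches=16, name='temporal_reshape')
restored = TemporalReshape.from_config(layer.get_config())
out = restored(tf.zeros((8 * 16, 4, 4, 3)))
print(tuple(out.shape))  # (8, 16, 4, 4, 3)
```

The same custom_objects={'TemporalReshape': TemporalReshape} argument is still needed when calling load_model.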

Answer 1 (score: 0)

After adding **kwargs to the __init__ function, I get the error: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"
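The traceback for this follow-up error is not shown, so its cause cannot be pinned down from here. One way to get "missing ... required positional arguments" is for the layer config stored in the file to lack batch_size and num_patches, so that cls(**config) supplies only the generic keys. That could happen, for example, if the .h5 file was written by an earlier version of the class whose get_config() did not return them; this is an assumption, not something the question confirms. The hypothetical helper missing_init_args() below compares a model's architecture config with the required arguments of a layer class's __init__. The TemporalReshape here is a plain-Python stand-in; the real Keras class can be passed in its place.

```python
import inspect


class TemporalReshape:
    """Hypothetical plain-Python stand-in with the **kwargs constructor signature."""

    def __init__(self, batch_size, num_patches, **kwargs):
        self.batch_size = batch_size
        self.num_patches = num_patches
        self.name = kwargs.get('name')


def missing_init_args(model_config, cls):
    """Hypothetical helper: map each stored layer of type `cls` to the required
    __init__ arguments that are absent from its saved config."""
    required = [
        p.name
        for p in inspect.signature(cls.__init__).parameters.values()
        if p.name != 'self'
        and p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                       inspect.Parameter.KEYWORD_ONLY)
    ]
    problems = {}
    # In TF 2.x, Functional and Sequential configs list layers under config['layers'].
    for entry in model_config['config']['layers']:
        if entry['class_name'] != cls.__name__:
            continue
        missing = [arg for arg in required if arg not in entry['config']]
        if missing:
            problems[entry['config'].get('name', '?')] = missing
    return problems


# A stored entry without the custom keys; from_config would call cls(**stored).
stored = {'name': 'temporal_reshape', 'trainable': True, 'dtype': 'float32'}
model_config = {'class_name': 'Functional',
                'config': {'name': 'model',
                           'layers': [{'class_name': 'TemporalReshape', 'config': stored}]}}
print(missing_init_args(model_config, TemporalReshape))
# {'temporal_reshape': ['batch_size', 'num_patches']}
try:
    TemporalReshape(**stored)
except TypeError as err:
    print(err)  # ... missing 2 required positional arguments: 'batch_size' and 'num_patches'

# To inspect the real file (assumes tf.keras's HDF5 saver, which stores the
# architecture JSON in the root attribute 'model_config'):
# import h5py, json
# with h5py.File('/content/saved_models/model_best.h5', 'r') as f:
#     print(missing_init_args(json.loads(f.attrs['model_config']), TemporalReshape))
```

If the stored config really lacks those keys, fixing the class alone will not help; the model would need to be re-saved with the corrected get_config().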