使用tf.keras.models.save_model()保存多输入TF 2.x子类模型时发生TypeError

时间:2019-12-19 22:02:58

标签: python tensorflow keras tensorflow2.0

在完成 this tutorial(该教程)中的训练过程之后,我尝试使用以下代码保存tf模型:

# Try to export the trained decoder in the TensorFlow SavedModel format (save_format='tf').
tf.keras.models.save_model(decoder, 'path', save_format='tf')

但出现错误消息:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-33-5a5dd79f753f> in <module>()
----> 1 tf.keras.models.save_model(decoder, './drive/My Drive/DeepLearning/decoder/kerasencoder', save_format='tf')

20 frames
/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115                           signatures, options)
    116 
    117 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
     76     # we use the default replica context here.
     77     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
---> 78       save_lib.save(model, filepath, signatures, options)
     79 
     80   if not include_optimizer:

/tensorflow-2.1.0/python3.6/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    884   if signatures is None:
    885     signatures = signature_serialization.find_function_to_export(
--> 886         checkpoint_graph_view)
    887 
    888   signatures = signature_serialization.canonicalize_signatures(signatures)

/tensorflow-2.1.0/python3.6/tensorflow_core/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     72   # If the user did not specify signatures, check the root object for a function
     73   # that can be made into a signature.
---> 74   functions = saveable_view.list_functions(saveable_view.root)
     75   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
     76   if signature is not None:

/tensorflow-2.1.0/python3.6/tensorflow_core/python/saved_model/save.py in list_functions(self, obj)
    140     if obj_functions is None:
    141       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
--> 142           self._serialization_cache)
    143       self._functions[obj] = obj_functions
    144     return obj_functions

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   2418   def _list_functions_for_serialization(self, serialization_cache):
   2419     return (self._trackable_saved_model_saver
-> 2420             .list_functions_for_serialization(serialization_cache))
   2421 
   2422   def __getstate__(self):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     89         `ConcreteFunction`.
     90     """
---> 91     fns = self.functions_to_serialize(serialization_cache)
     92 
     93     # The parent AutoTrackable class saves all user-defined tf.functions, and

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     78   def functions_to_serialize(self, serialization_cache):
     79     return (self._get_serialized_attributes(
---> 80         serialization_cache).functions_to_serialize)
     81 
     82   def _get_serialized_attributes(self, serialization_cache):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     93 
     94     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95         serialization_cache)
     96 
     97     serialized_attr.set_and_validate_objects(object_dict)

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     45     # cache (i.e. this is the root level object).
     46     if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 47       default_signature = save_impl.default_save_signature(self.obj)
     48 
     49     # Other than the default signature function, all other attributes match with

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
    210   original_losses = _reset_layer_losses(layer)
    211   fn = saving_utils.trace_model_call(layer)
--> 212   fn.get_concrete_function()
    213   _restore_layer_losses(original_losses)
    214   return fn

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
    907       if self._stateful_fn is None:
    908         initializers = []
--> 909         self._initialize(args, kwargs, add_initializers_to=initializers)
    910         self._initialize_uninitialized_variables(initializers)
    911 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    495     self._concrete_stateful_fn = (
    496         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 497             *args, **kwds))
    498 
    499     def invalid_creator_scope(*unused_args, **unused_kwds):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2387       args, kwargs = None, None
   2388     with self._lock:
-> 2389       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2390     return graph_function
   2391 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

/tensorflow-2.1.0/python3.6/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/saving/saving_utils.py in _wrapped_model(*args)
    148     with base_layer_utils.call_context().enter(
    149         model, inputs=inputs, build_graph=False, training=False, saving=True):
--> 150       outputs_list = nest.flatten(model(inputs=inputs, training=False))
    151 
    152     try:

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    776                     outputs = base_layer_utils.mark_as_return(outputs, acd)
    777                 else:
--> 778                   outputs = call_fn(cast_inputs, *args, **kwargs)
    779 
    780             except errors.OperatorNotAllowedInGraphError as e:

/tensorflow-2.1.0/python3.6/tensorflow_core/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    290   def wrapper(*args, **kwargs):
    291     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 292       return func(*args, **kwargs)
    293 
    294   if inspect.isfunction(func) or inspect.ismethod(func):

TypeError: call() missing 2 required positional arguments: 'hidden' and 'enc_output'

在链接中,解码器定义为:

class Decoder(tf.keras.Model):
  """Single-step GRU decoder that attends over the encoder output.

  Every call embeds the input, concatenates the embedding with an attention
  context vector derived from `hidden` and `enc_output`, feeds the result
  through a GRU and finally through a Dense layer with `vocab_size` units.
  """

  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    """Create the sub-layers used by `call`.

    Args:
      vocab_size: Input dimension of the embedding layer and number of units
        of the final Dense layer.
      embedding_dim: Output dimension of the embedding layer.
      dec_units: Number of GRU units; also handed to `BahdanauAttention`.
      batch_sz: Stored on the instance; not read anywhere else in this class.
    """
    super().__init__()
    layers = tf.keras.layers

    self.batch_sz = batch_sz
    self.dec_units = dec_units

    # Sub-layers are created in the same order as before so that the
    # tracked-layer order stays the same.
    self.embedding = layers.Embedding(vocab_size, embedding_dim)
    self.gru = layers.GRU(
        dec_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform',
    )
    self.fc = layers.Dense(vocab_size)

    # Attention module applied to (hidden, enc_output) in `call`.
    self.attention = BahdanauAttention(dec_units)

  def call(self, x, hidden, enc_output):
    """Run one decoding step.

    Args:
      x: Decoder input passed to the embedding layer.
      hidden: State given to the attention layer. It is only used for
        attention; the GRU is invoked without an explicit initial state.
      enc_output: Encoder output, expected to have shape
        (batch_size, max_length, hidden_size).

    Returns:
      A tuple `(x, state, attention_weights)`: the Dense-layer output, the
      state returned by the GRU, and the weights returned by the attention
      layer.
    """
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # Embedded input: (batch_size, 1, embedding_dim).
    embedded = self.embedding(x)

    # Insert a length-1 axis at position 1 on the context vector, then join it
    # with the embedding along the last axis:
    # (batch_size, 1, embedding_dim + hidden_size).
    context_step = tf.expand_dims(context_vector, 1)
    gru_input = tf.concat([context_step, embedded], axis=-1)

    output, state = self.gru(gru_input)

    # Collapse the leading axes: (batch_size * 1, hidden_size).
    flat_output = tf.reshape(output, (-1, output.shape[2]))

    # Final projection: (batch_size, vocab).
    projected = self.fc(flat_output)

    return projected, state, attention_weights

如何解决此错误?

1 个答案:

答案 0 :(得分:2)

该错误很可能是由序列化问题引起的。

有多种方法可以保存Tensorflow模型。在您提供的教程中,他们使用的是tf.train.Checkpoint,主要是因为子类化的tf.keras.Model无法通过tf.keras.models.save_model或model.save安全地序列化。

看看docs(尤其是自定义对象部分)和this issue

使用子类化API时,通过model.save_weights仅保存模型权重更为安全。

希望这会有所帮助。