tf.keras.estimator.model_to_estimator fails to convert a Keras model with custom layers and Lambda layers

Time: 2019-10-25 20:03:39

Tags: tensorflow keras python-3.6 tensorflow-estimator tf.keras

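A while ago I wrote a model that uses some custom layer definitions and trained it with TF 1.12 and standalone Keras 2.2.4. I have since upgraded TF to 1.14 and switched to tf.keras. With a custom load function, the model builds, loads its weights, and produces predictions.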

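Now I am trying to convert the Keras model into a TF Estimator (which I can use for inference), and I am running into all sorts of problems. I believe they stem from the get_config() method of my Lambda layers, which I currently define as follows: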

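# Imports assumed here; the original post does not show them.
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda

class NamedLambda(Lambda):
    # Wraps the subclass's fn method in a Lambda layer with a fixed name
    def __init__(self, name=None):
        Lambda.__init__(self, self.fn, name=name)

    @classmethod
    def invoke(cls, args, **kw):
        # Build the layer and apply it to args in one call
        return cls(**kw)(args)

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self.name)

class L2Normalize(NamedLambda):
    def fn(self, x):
        return K.l2_normalize(x, axis=-1)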

When I check it, the get_config method works fine:

custom_objects['l2_normalize'].get_config()
{'arguments': DictWrapper({}),
 'dtype': 'float32',
 'function': 'fn',
 'function_type': 'function',
 'module': 'grademachine.utils',
 'name': 'l2_normalize',
 'output_shape': None,
 'output_shape_module': None,
 'output_shape_type': 'raw',
 'trainable': True}
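A successful get_config() covers only half of the round trip. Per the traceback below, cloning calls layer.__class__.from_config(layer.get_config()), and the inherited Lambda from_config then has to turn the string 'fn' back into a callable. One possible workaround, untested against TF 1.14, is to have NamedLambda serialize only its name and rebuild itself in from_config, since fn is defined on the class anyway. The sketch below uses a plain-Python stand-in, FakeLambda (hypothetical, not tf.keras.layers.Lambda), so it runs without TensorFlow; a real tf.keras version would likely need to keep the base Layer keys (name, trainable, dtype) in its config.

```python
import math


class FakeLambda:
    """Hypothetical stand-in for tf.keras.layers.Lambda (not the real class)."""

    def __init__(self, function, name=None):
        self.function = function
        self.name = name

    def get_config(self):
        # Mirrors the config in the question: a named function is stored
        # only by its __name__ ('fn'), not as a reference to the method.
        return {'name': self.name,
                'function': self.function.__name__,
                'function_type': 'function'}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Rebuilding has to resolve the stored string back to a callable.
        custom_objects = custom_objects or {}
        fn_name = config['function']
        if fn_name not in custom_objects:
            raise ValueError('Unknown function in Lambda layer:' + fn_name)
        return cls(custom_objects[fn_name], name=config['name'])


class NamedLambda(FakeLambda):
    def __init__(self, name=None):
        FakeLambda.__init__(self, self.fn, name=name)

    # Workaround sketch: fn lives on the subclass, so the layer name alone
    # is enough to rebuild the layer; no function lookup by name is needed.
    def get_config(self):
        return {'name': self.name}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(name=config['name'])


class L2Normalize(NamedLambda):
    def fn(self, x):
        # Pure-Python L2 normalization of a flat list (stands in for K.l2_normalize)
        norm = math.sqrt(sum(v * v for v in x))
        return [v / norm for v in x]


# Same call pattern as _clone_layer in the traceback
layer = L2Normalize(name='l2_normalize')
clone = layer.__class__.from_config(layer.get_config())
print(type(clone).__name__, clone.name, clone.function([3.0, 4.0]))
```

Because _clone_layer uses layer.__class__ directly, custom_objects is not what resolves the layer class in this path. In Keras, custom_objects is conventionally keyed by class or function name (for example 'L2Normalize': L2Normalize), not by layer instance name.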

Below is some example code, along with the traceback that has me confused. Any help would be greatly appreciated.

  • Python version: 3.6.2
  • TensorFlow version: 1.14.0
  • Keras version: 2.2.4-tf
model = load_model(model_dir, 
                   options_fn='model123_options', 
                   weights_fn='model123_weights')
model
<tensorflow.python.keras.engine.training.Model at 0x7fe3d43d8e10>
est = tf.keras.estimator.model_to_estimator(keras_model=model)

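I also tried passing my custom layers as custom_objects, as shown below. This produces a slightly different traceback, but it ends up failing in the same place. The traceback below comes from the version with custom_objects defined: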

# custom_layer_names is a list of names of each of the custom layers in the trained model
custom_objects = {l.name: l for l in model.layers if l.name in custom_layer_names}
est = tf.keras.estimator.model_to_estimator(keras_model=model,  
                                            custom_objects=custom_objects)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpyujm6s99
INFO:tensorflow:Using the Keras model provided.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-512a382c338c> in <module>()
     13 est = tf.keras.estimator.model_to_estimator(keras_model=model, 
     14                                             model_dir='saved_estimator/',
---> 15                                             custom_objects=custom_objects)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/estimator/__init__.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
     71       custom_objects=custom_objects,
     72       model_dir=model_dir,
---> 73       config=config)
     74 
     75 # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
    448   if keras_model._is_graph_network:
    449     warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
--> 450                                              config)
    451   elif keras_model.built:
    452     logging.warning('You are creating an Estimator from a Keras model manually '

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _save_first_checkpoint(keras_model, custom_objects, config)
    316       training_util.create_global_step()
    317       model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,
--> 318                                      custom_objects)
    319       # save to checkpoint
    320       with session.Session(config=config.session_config) as sess:

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _clone_and_build_model(mode, keras_model, custom_objects, features, labels)
    199       compile_clone=compile_clone,
    200       in_place_reset=(not keras_model._is_graph_network),
--> 201       optimizer_iterations=global_step)
    202 
    203   return clone

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_and_build_model(model, input_tensors, target_tensors, custom_objects, compile_clone, in_place_reset, optimizer_iterations, optimizer_config)
    534     if custom_objects:
    535       with CustomObjectScope(custom_objects):
--> 536         clone = clone_model(model, input_tensors=input_tensors)
    537     else:
    538       clone = clone_model(model, input_tensors=input_tensors)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors, clone_function)
    324   else:
    325     return _clone_functional_model(
--> 326         model, input_tensors=input_tensors, layer_fn=clone_function)
    327 
    328 

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors, layer_fn)
    152       # Get or create layer.
    153       if layer not in layer_map:
--> 154         new_layer = layer_fn(layer)
    155         layer_map[layer] = new_layer
    156         layer = new_layer

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_layer(layer)
     52 
     53 def _clone_layer(layer):
---> 54   return layer.__class__.from_config(layer.get_config())
     55 
     56 

~/repos/grademachine/grademachine/utils.py in from_config(cls, config, custom_objects)
    850     config = config.copy()
    851     function = cls._parse_function_from_config(
--> 852         config, custom_objects, 'function', 'module', 'function_type')
    853 
    854     output_shape = cls._parse_function_from_config(

~/repos/grademachine/grademachine/utils.py in _parse_function_from_config(cls, config, custom_objects, func_attr_name, module_attr_name, func_type_attr_name)
    898           config[func_attr_name],
    899           custom_objects=custom_objects,
--> 900           printable_module_name='function in Lambda layer')
    901     elif function_type == 'lambda':
    902       # Unsafe deserialization from bytecode

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    207       obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    208     else:
--> 209       obj = module_objects.get(object_name)
    210       if obj is None:
    211         raise ValueError('Unknown ' + printable_module_name + ':' + object_name)

AttributeError: 'NoneType' object has no attribute 'get'
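The last two frames account for the AttributeError. The Lambda config stores the function as the string 'fn', and the _parse_function_from_config frame passes custom_objects to deserialize_keras_object but no module_objects. So when 'fn' is in neither custom_objects nor the global custom objects (which CustomObjectScope fills from the same dict during cloning), the lookup falls through to module_objects.get(...) on None. The stdlib-only sketch below is a simplified model of that lookup order as it appears in the traceback; it is not the actual TensorFlow source, and lookup_by_name is a hypothetical name.

```python
# Simplified model of the string lookup in the traceback's
# deserialize_keras_object frame (TF 1.14). Not the real TensorFlow code.
_GLOBAL_CUSTOM_OBJECTS = {}


def lookup_by_name(object_name, module_objects=None, custom_objects=None):
    if custom_objects and object_name in custom_objects:
        return custom_objects[object_name]
    if object_name in _GLOBAL_CUSTOM_OBJECTS:
        return _GLOBAL_CUSTOM_OBJECTS[object_name]
    # module_objects is None on this path, so this line raises
    # AttributeError: 'NoneType' object has no attribute 'get'
    obj = module_objects.get(object_name)
    if obj is None:
        raise ValueError('Unknown function in Lambda layer:' + object_name)
    return obj


config = {'function': 'fn', 'function_type': 'function', 'name': 'l2_normalize'}

# Keys like those built by the custom_objects comprehension in the question:
# layer names, not the function name stored in the config.
custom_objects = {'l2_normalize': object()}

try:
    lookup_by_name(config['function'], custom_objects=custom_objects)
except AttributeError as err:
    print('AttributeError:', err)
```

Keying custom_objects by 'fn' would get past this particular lookup, but every NamedLambda subclass names its method fn, so a single name cannot tell the layers apart.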
