加载 TensorFlow SavedModel（saved_model.pb 文件及其文件夹）时报错 KeyError: 0

时间:2020-04-13 17:34:05

标签: python tensorflow2.0

保存模型后，名为“MODEL_X”的文件夹中包含 assets、saved_model.pb 和 variables。

现在，当我把该文件夹路径传给 tf.keras.models.load_model(file_path) 时，出现了以下错误。

我该如何解决?

<ipython-input-15-8838bb61f3d3> in <module>
----> 1 new_model = tf.keras.models.load_model(file_path)

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
    148   if isinstance(filepath, six.string_types):
    149     loader_impl.parse_saved_model(filepath)
--> 150     return saved_model_load.load(filepath, compile)
    151 
    152   raise IOError(

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/load.py in load(path, compile)
     87   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
     88   # TODO(kathywu): Add code to load from objects that contain all endpoints
---> 89   model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
     90 
     91   # pylint: disable=protected-access

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in load_internal(export_dir, tags, loader_cls)
    550       loader = loader_cls(object_graph_proto,
    551                           saved_model_proto,
--> 552                           export_dir)
    553       root = loader.get(0)
    554     root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/load.py in __init__(self, *args, **kwargs)
    116 
    117   def __init__(self, *args, **kwargs):
--> 118     super(KerasObjectLoader, self).__init__(*args, **kwargs)
    119     self._finalize()
    120 

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir)
    119       self._concrete_functions[name] = _WrapperFunction(concrete_function)
    120 
--> 121     self._load_all()
    122     # TODO(b/124045874): There are limitations with functions whose captures
    123     # trigger other functions to be executed. For now it is only guaranteed to

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in _load_all(self)
    237         # interface.
    238         continue
--> 239       node, setter = self._recreate(proto)
    240       nodes[node_id] = node
    241       node_setters[node_id] = setter

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in _recreate(self, proto)
    320     if kind not in factory:
    321       raise ValueError("Unknown SavedObject type: %r" % kind)
--> 322     return factory[kind]()
    323 
    324   def _recreate_user_object(self, proto):

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in <lambda>()
    307     """Creates a Python object from a SavedObject protocol buffer."""
    308     factory = {
--> 309         "user_object": lambda: self._recreate_user_object(proto.user_object),
    310         "asset": lambda: self._recreate_asset(proto.asset),
    311         "function": lambda: self._recreate_function(proto.function),

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/load.py in _recreate_user_object(self, proto)
    326     looked_up = revived_types.deserialize(proto)
    327     if looked_up is None:
--> 328       return self._recreate_base_user_object(proto)
    329     return looked_up
    330 

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/load.py in _recreate_base_user_object(self, proto)
    214           parent_classes,
    215           {'__setattr__': parent_classes[1].__setattr__})
--> 216       return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
    217 
    218     return super(KerasObjectLoader, self)._recreate_base_user_object(proto)

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/load.py in _init_from_metadata(cls, metadata)
    295         ragged=metadata['ragged'],
    296         batch_input_shape=metadata['batch_input_shape'])
--> 297     revived_obj = cls(**init_args)
    298     with trackable.no_automatic_dependency_tracking_scope(revived_obj):
    299       revived_obj._config = metadata['config']  # pylint:disable=protected-access

~/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/input_layer.py in __init__(self, input_shape, batch_size, dtype, input_tensor, sparse, name, ragged, **kwargs)
     84                          'batch_input_shape argument to '
     85                          'InputLayer, not both at the same time.')
---> 86       batch_size = batch_input_shape[0]
     87       input_shape = batch_input_shape[1:]
     88     if kwargs:

KeyError: 0

1 个答案:

答案 0（得分：0）

我在 Anaconda3 中使用 TensorFlow 2.1.1 时遇到了同样的错误。用 pip 卸载 TensorFlow 2.1.1 并重新安装 TensorFlow 2.2.0 后，错误消失了。因此该问题看起来与 TensorFlow 版本有关，建议先检查当前安装的版本。

检查版本

import tensorflow
print(tensorflow.__version__)

为用户和系统卸载

/usr/local/anaconda/bin/pip uninstall tensorflow
sudo /usr/local/anaconda/bin/pip uninstall tensorflow

重新安装

sudo /usr/local/anaconda/bin/pip install --ignore-installed  --upgrade tensorflow --no-cache-dir

另外，在一台运行 Arch Linux 的 PC 上，升级到 2.2.0 后又出现了另一个错误：

AttributeError: module 'tensorflow.python.keras.utils.generic_utils' has no attribute 'populate_dict_with_module_objects'

改用 tf-nightly（TensorFlow 每夜构建版）后，这个问题得到了解决。