我想从TensorFlow1中将由.index
,.meta
,.data-00000-of-00001
个文件组成的检查点加载到tensorflow2.0.0
中,并将其转换为keras模型,以便能够无需tf.Session
就可以在急切模式下本地使用它。这是我运行的代码:
import tensorflow as tf
import numpy as np
from tensorflow.python.keras.backend import set_session
from tensorflow.python.training.saver import _import_meta_graph_with_return_elements
def save_ckpt(ckpt_path='test'):
'''save TensorFlow-1 Checkpoint '''
with tf.Graph().as_default() as g:
in_op = tf.constant(np.random.rand(1,2,2,2),name='input',dtype=tf.float32)
out_op = tf.keras.layers.Conv2D(3,(3,3),padding='same',name='MY_LAYER')(in_op)
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.variables_initializer(tf.compat.v1.global_variables()))
saver.save(sess,ckpt_path)
def load_ckpt():
'''KerasModel from meta & ckpt'''
in_op = tf.keras.Input([2,2,2])
_m = tf.keras.models.Model(inputs=in_op,outputs=in_op)
with _m.input.graph.as_default() as g:
saver, out_op = _import_meta_graph_with_return_elements('test.meta',
input_map={'input':_m.output},
return_elements=[
# 'input:0',
'MY_LAYER/Conv2D:0'
])
with tf.compat.v1.Session() as sess:
saver.restore(sess,'test')
set_session(sess)
out_mdl = tf.keras.models.Model(inputs=_m.input, outputs=out_op[0])
return out_mdl
# main
save_ckpt() # save name based checkpoint
meta_model = load_ckpt() # restore in keras model
oo = meta_model(np.random.rand(1,2,2,2)) # run the model
print(oo)
但我收到此错误:
Traceback (most recent call last):
File "question2.py", line 38, in <module>
meta_model = load_ckpt() # restore in keras model
File "question2.py", line 32, in load_ckpt
out_mdl = tf.keras.models.Model(inputs=_m.input, outputs=out_op[0])
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 146, in __init__
super(Model, self).__init__(*args, **kwargs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 167, in __init__
self._init_graph_network(*args, **kwargs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
result = method(self, *args, **kwargs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 270, in _init_graph_network
base_layer_utils.create_keras_history(self._nested_outputs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py", line 184, in create_keras_history
_, created_layers = _create_keras_history_helper(tensors, set(), [])
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py", line 229, in _create_keras_history_helper
constants[i] = backend.function([], op_input)([])
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py", line 3740, in __call__
outputs = self._graph_fn(*converted_inputs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1081, in __call__
return self._call_impl(args, kwargs)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
return self._call_flat(args, self.captured_inputs, cancellation_manager)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
ctx, args, cancellation_manager=cancellation_manager)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
ctx=ctx)
File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable MY_LAYER/kernel from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/MY_LAYER/kernel)
[[node MY_LAYER/Conv2D/ReadVariableOp (defined at /home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_72]
Function call stack:
keras_scratch_graph
到目前为止我尝试过的事情
在MY_LAYER/Conv2D:0
中用input:0
替换_import_meta_graph_with_return_elements()
可使代码正常运行。