在TensorFlow2中加载TensorFlow1检查点-“函数调用堆栈:keras_scratch_graph”

时间:2019-10-01 07:58:31

标签: python tensorflow keras tensorflow2.0

我想从TensorFlow1中将由.index.meta.data-00000-of-00001个文件组成的检查点加载到tensorflow2.0.0中,并将其转换为keras模型,以便能够无需tf.Session就可以在急切模式下本地使用它。这是我运行的代码:

import tensorflow as tf
import numpy as np
from tensorflow.python.keras.backend import set_session
from tensorflow.python.training.saver import _import_meta_graph_with_return_elements


def save_ckpt(ckpt_path='test'):
  '''save TensorFlow-1 Checkpoint '''
  with tf.Graph().as_default() as g:
    in_op  = tf.constant(np.random.rand(1,2,2,2),name='input',dtype=tf.float32)
    out_op = tf.keras.layers.Conv2D(3,(3,3),padding='same',name='MY_LAYER')(in_op)
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.variables_initializer(tf.compat.v1.global_variables()))
        saver.save(sess,ckpt_path)


def load_ckpt():
    '''KerasModel from meta & ckpt'''
    in_op = tf.keras.Input([2,2,2])
    _m = tf.keras.models.Model(inputs=in_op,outputs=in_op)    
    with _m.input.graph.as_default() as g:
        saver, out_op =  _import_meta_graph_with_return_elements('test.meta',
                                      input_map={'input':_m.output},
                                      return_elements=[
                                                    # 'input:0',
                                                      'MY_LAYER/Conv2D:0'
                                                      ])
        with tf.compat.v1.Session() as sess:
                saver.restore(sess,'test')
                set_session(sess)
                out_mdl = tf.keras.models.Model(inputs=_m.input, outputs=out_op[0]) 
    return out_mdl


# main
save_ckpt()              # save name based checkpoint
meta_model = load_ckpt() # restore in keras model
oo = meta_model(np.random.rand(1,2,2,2)) # run the model
print(oo)

但我收到此错误:

Traceback (most recent call last):
  File "question2.py", line 38, in <module>
    meta_model = load_ckpt() # restore in keras model
  File "question2.py", line 32, in load_ckpt
    out_mdl = tf.keras.models.Model(inputs=_m.input, outputs=out_op[0]) 
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 146, in __init__
    super(Model, self).__init__(*args, **kwargs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 167, in __init__
    self._init_graph_network(*args, **kwargs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 270, in _init_graph_network
    base_layer_utils.create_keras_history(self._nested_outputs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py", line 184, in create_keras_history
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py", line 229, in _create_keras_history_helper
    constants[i] = backend.function([], op_input)([])
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py", line 3740, in __call__
    outputs = self._graph_fn(*converted_inputs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1081, in __call__
    return self._call_impl(args, kwargs)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.FailedPreconditionError:  Error while reading resource variable MY_LAYER/kernel from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/MY_LAYER/kernel)
     [[node MY_LAYER/Conv2D/ReadVariableOp (defined at /home/dionyssos/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_72]

Function call stack:
keras_scratch_graph

到目前为止我尝试过的事情

MY_LAYER/Conv2D:0中用input:0替换_import_meta_graph_with_return_elements()可使代码正常运行。

0 个答案:

没有答案