Freezing a TensorFlow model into a .pb file

Time: 2020-03-12 08:27:12

Tags: tensorflow graph model anaconda meta

I am trying to freeze a TensorFlow model.

Training from scratch in TensorFlow produced the following 4 files:

  1. model.ckpt-454501.data-00000-of-00001

  2. model.ckpt-454501.index

  3. model.ckpt-454501.meta

  4. checkpoint
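
The `.index`, `.data-*` and `.meta` files share the checkpoint prefix `model.ckpt-454501`. The plain-text `checkpoint` file only records which prefix is the latest, and `tf.train.latest_checkpoint` reads that record. The sketch below imitates that lookup in plain Python. It assumes the usual text layout, where a line starts with `model_checkpoint_path: "..."`. The helpers `latest_prefix` and `files_for_prefix` are hypothetical and are not TensorFlow APIs.

```python
import re
from typing import List, Optional


def latest_prefix(checkpoint_text: str) -> Optional[str]:
    """Return the prefix named on the model_checkpoint_path line (hypothetical helper)."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', checkpoint_text, re.MULTILINE)
    return match.group(1) if match else None


def files_for_prefix(prefix: str, shards: int = 1) -> List[str]:
    """List the files that together form one checkpoint with the given prefix.

    shards=1 matches the '.data-00000-of-00001' name seen in the question.
    """
    data_files = [f"{prefix}.data-{i:05d}-of-{shards:05d}" for i in range(shards)]
    return data_files + [f"{prefix}.index", f"{prefix}.meta"]


if __name__ == "__main__":
    # Assumed contents of the 'checkpoint' text file
    text = (
        'model_checkpoint_path: "model.ckpt-454501"\n'
        'all_model_checkpoint_paths: "model.ckpt-454501"\n'
    )
    prefix = latest_prefix(text)
    print(prefix)                    # model.ckpt-454501
    print(files_for_prefix(prefix))  # the .data, .index and .meta file names
```

For freezing, the `.meta` file supplies the graph and the `.index`/`.data` pair supplies the weights. `saver.restore` expects the prefix (for example `.../model.ckpt-454501`), not any one of the individual file names.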

I want to convert them (or only the files that are actually needed) into a single graph.pb file. I used this code:

import tensorflow as tf

meta_path = 'model.ckpt-454501.meta'  # Your .meta file
# output_node_names = ['output:0']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess, tf.train.latest_checkpoint('E:\OpenVino\SANGKV'))

    output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

I got this error:

Traceback (most recent call last):
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_call
    return fn(*args)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1320, in _run_fn
    self._extend_graph()
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1381, in _extend_graph
    self._session, graph_def.SerializeToString(), status)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "freeze_graph1.py", line 11, in <module>
    saver.restore(sess, tf.train.latest_checkpoint('E:\OpenVino\SANGKV'))
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py", line 1686, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 895, in run
    run_metadata_ptr)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1344, in _do_run
    options, run_metadata)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]

Caused by op 'feature_fusion/Conv_9/biases/ExponentialMovingAverage', defined at:
  File "freeze_graph1.py", line 8, in <module>
    saver = tf.train.import_meta_graph(meta_path)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py", line 1838, in import_meta_graph
    **kwargs)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 660, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 316, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\importer.py", line 554, in import_graph_def
    op_def=op_def)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3160, in create_op
    op_def=op_def)
  File "C:\Users\Lerror\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'feature_fusion/Conv_9/biases/ExponentialMovingAverage': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
         [[Node: feature_fusion/Conv_9/biases/ExponentialMovingAverage = VariableV2[_class=["loc:@feature_fusion/Conv_9/biases"], container="", dtype=DT_FLOAT, shape=[1], shared_name="", _device="/device:GPU:0"]()]]
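
The message says the `.meta` file still carries the `/device:GPU:0` placement recorded during training. This machine (or this TensorFlow build) only exposes a CPU. As far as we understand TensorFlow 1.x, a device string already stored on an imported node takes precedence over an enclosing `tf.device(...)` scope. CPU placement therefore only takes effect once the stored string is cleared, which `tf.train.import_meta_graph(..., clear_devices=True)` does. The toy model below is plain Python, not TensorFlow code, and only illustrates that assumed precedence rule. `parse_device` and `merge` are hypothetical helpers with simplified parsing.

```python
from typing import Dict


def parse_device(spec: str) -> Dict[str, str]:
    """Split a device string such as '/job:localhost/device:GPU:0' into fields (simplified)."""
    fields: Dict[str, str] = {}
    for part in filter(None, spec.split("/")):
        key, _, value = part.partition(":")
        key = key.lower()
        if key == "device":
            # '/device:GPU:0' -> type 'GPU', index '0'
            dev_type, _, index = value.partition(":")
            fields["type"] = dev_type.upper()
            if index:
                fields["index"] = index
        elif key in ("cpu", "gpu"):
            # legacy shorthand such as '/cpu:0'
            fields["type"] = key.upper()
            if value:
                fields["index"] = value
        else:
            # job, replica, task
            fields[key] = value
    return fields


def merge(scope: str, node_device: str) -> Dict[str, str]:
    """Toy rule: fields already set on the node override those of the enclosing scope."""
    merged = parse_device(scope)
    merged.update(parse_device(node_device))
    return merged


if __name__ == "__main__":
    # Device saved in the .meta file wins over tf.device("/cpu:0")
    print(merge("/cpu:0", "/device:GPU:0"))  # {'type': 'GPU', 'index': '0'}
    # After clear_devices=True the node has no device, so the scope applies
    print(merge("/cpu:0", ""))               # {'type': 'CPU', 'index': '0'}
```

There is another commonly used option. Creating the session with `tf.ConfigProto(allow_soft_placement=True)` lets TensorFlow fall back to an available device instead of failing on the unavailable GPU.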

I don't know whether the problem is Anaconda or my code.

I hope you can help me.

Thanks

1 answer:

Answer (score: 0)

Please try the following code:

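import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.import_meta_graph)

meta_path = 'model.ckpt-454501.meta'  # Your .meta file
# output_node_names = ['output']    # Output *node* names (no ':0' suffix)

with tf.Session() as sess:

    with tf.device("/cpu:0"):
        # Restore the graph. clear_devices=True drops the /device:GPU:0
        # placements saved in the .meta file; without it they take
        # precedence over the enclosing tf.device scope.
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)

    # Load weights (raw string so the Windows backslashes stay literal)
    saver.restore(sess, tf.train.latest_checkpoint(r'E:\OpenVino\SANGKV'))

    # Every node in the graph; list only the real output nodes instead
    # to keep training-only ops out of the frozen graph
    output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

    # Freeze the graph: replace variables with constants holding their current values
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())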
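
`convert_variables_to_constants` takes *node* names such as `output`, not tensor names such as `output:0`. The commented-out line in the question uses the tensor form. Listing every node, as both snippets do, runs, but it keeps training-only ops (savers, optimizer state) in the frozen file. The small plain-Python helper below normalises tensor names before they are passed in. `to_node_names` is hypothetical and not part of TensorFlow.

```python
from typing import Iterable, List


def to_node_names(names: Iterable[str]) -> List[str]:
    """Turn tensor names into node names (hypothetical helper).

    Strips a leading '^' (control-input marker) and a trailing ':<index>'
    output suffix, keeping the first occurrence of each node in order.
    """
    result: List[str] = []
    for name in names:
        if name.startswith("^"):
            name = name[1:]
        base, sep, index = name.rpartition(":")
        if sep and index.isdigit():
            name = base
        if name not in result:
            result.append(name)
    return result


if __name__ == "__main__":
    print(to_node_names(["output:0", "output:1", "feature_fusion/Conv_9/biases"]))
    # ['output', 'feature_fusion/Conv_9/biases']
```

The result can be passed as the third argument of `convert_variables_to_constants` in place of the list of every node.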