How do I reload the weights and biases of a CNN with TensorFlow?

Time: 2018-05-20 21:09:54

Tags: python tensorflow

I trained a model with TensorFlow and exported the meta graph. When I then import the trained graph, restore the saved variables, and try to evaluate y_pred, an InvalidArgumentError is raised (the full traceback is given in the edit below).

What should I do? Also, is there a way to visualize the graph I created?

Edit

The full traceback is:

"C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\python.exe" C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py
Traceback (most recent call last):
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1327, in _do_call
return fn(*args)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1312, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1420, in _call_tf_sessionrun
status, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 516, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py", line 62, in <module>
print(sess.run('y_pred:0'))
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 905, in run
run_metadata_ptr)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1140, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1321, in _do_run
run_metadata)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\client\session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'x', defined at:
File "C:/Users/fredd/PycharmProjects/CNN/detectionDemo.py", line 60, in <module>
saver = tf.train.import_meta_graph('results/steering_model.meta')
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\training\saver.py", line 1927, in import_meta_graph
**kwargs)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 741, in import_scoped_meta_graph
producer_op_list=producer_op_list)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\importer.py", line 577, in import_graph_def
op_def=op_def)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
op_def=op_def)
File "C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python36_64\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x' with dtype float and shape [16,96,128,3]
 [[Node: x = Placeholder[dtype=DT_FLOAT, shape=[16,96,128,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

And the code that loads the CNN is:

sess = tf.Session()
saver = tf.train.import_meta_graph('results/steering_model.meta')
saver.restore(sess, 'results/steering_model')
print(sess.run('y_pred:0'))

The network was saved successfully, but after importing it I cannot use any of the previously saved variables.

1 answer:

Answer 0 (score: 2)

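The error has nothing to do with saving or loading; it comes from the session.run call. The graph you saved and loaded contains a placeholder (x), and you have to feed it through the feed_dict argument of Session.run, exactly as you would if you had built the graph by hand. You can get a handle to the placeholder with graph.get_tensor_by_name.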
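
As a rough illustration of the above, here is a minimal sketch that saves a toy graph, re-imports it, and feeds the placeholder by name. The toy model (one matmul), the checkpoint name toy_model, and the input shape are assumptions made only for this example; they are not the CNN from the question. It also assumes a TensorFlow install that exposes the 1.x-style API as tensorflow.compat.v1 (with plain TensorFlow 1.x, import tensorflow as tf and drop the disable_eager_execution call).

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via compat.v1

tf.disable_eager_execution()

# Hypothetical checkpoint location, used only for this sketch.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "toy_model")

# Build and save a toy graph (a stand-in for the trained CNN).
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf.get_variable("w", shape=[4, 2], initializer=tf.ones_initializer())
    y_pred = tf.matmul(x, w, name="y_pred")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_prefix)  # also writes toy_model.meta

# Re-import the graph, restore the variables, and feed the placeholder.
with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        saver.restore(sess, ckpt_prefix)

        # Look the tensors up by name ("<op name>:<output index>").
        x_in = graph.get_tensor_by_name("x:0")
        y_out = graph.get_tensor_by_name("y_pred:0")

        batch = np.ones((3, 4), dtype=np.float32)
        result = sess.run(y_out, feed_dict={x_in: batch})

print(result.shape)
```

In the question's graph the placeholder x has the fixed shape [16, 96, 128, 3], so the array passed in feed_dict would need exactly that shape, for example a batch of 16 images of 96x128 pixels with 3 channels.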