I'm trying to freeze a toy graph, but I get the following IndexError. What's a good way forward? Any suggestions? Here are the code and the stack trace:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.python.tools import freeze_graph
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax, name="output_node"))
model.compile(loss='sparse_categorical_crossentropy', optimizer='rmsprop')  # MNIST labels are integer class ids, not one-hot
model.fit(train_images, train_labels, epochs=1)
# GRAPH SAVING - '.pbtxt'
tf.train.write_graph(K.get_session().graph_def, 'out', 'my_graph_name_graph.pbtxt')
# GRAPH SAVING - '.chkp'
# KEY: this saves the variables' current values as a checkpoint (hence '.chkp'), plus a '.meta' graph file
tf.train.Saver().save(K.get_session(), 'out/my_graph_name.chkp')
# GRAPH SAVING - '.bytes'
freeze_graph.freeze_graph('out/my_graph_name_graph.pbtxt', None, False,
'out/my_graph_name.chkp', model.output.op.name,  # the output op (e.g. "output_node/Softmax"), not the layer name
"save/restore_all", "save/Const:0",
'out/frozen_my_graph_name.bytes', True, "")
print("done!")
Stacktrace:
  File "test.py", line 35, in <module>
    'out/frozen_' + GRAPH_NAME + '.bytes', True, "")
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 363, in freeze_graph
    checkpoint_version=checkpoint_version)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\tools\freeze_graph.py", line 190, in freeze_graph_with_def_protos
    var_list=var_list, write_version=checkpoint_version)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 1102, in __init__
    self.build()
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 1114, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 1151, in _build
    build_save=build_save, build_restore=build_restore)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 773, in _build_internal
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 680, in _ValidateAndSliceInputs
    for converted_saveable_object in self.SaveableObjectsForOp(op, name):
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 654, in SaveableObjectsForOp
    variable, "", name)
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\training\saver.py", line 128, in __init__
    self.handle_op = var.op.inputs[0]
  File "C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env\lib\site-packages\tensorflow\python\framework\ops.py", line 2128, in __getitem__
    return self._inputs[i]
IndexError: list index out of range
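The trace ends in saver.py at `self.handle_op = var.op.inputs[0]`, which means the tensor the Saver was handed comes from an op with no inputs. In that part of saver.py, any variable op that is not a classic `VariableV2` goes down the resource-variable path. A quick diagnostic sketch, assuming TF 1.x graph mode as in the environment listed here (TensorFlow 1.12), that lists which kind of variable ops a tf.keras model creates: `VarHandleOp` means resource variables, `VariableV2` means classic reference variables. The tiny model is a stand-in, not the question's exact model.

```python
import tensorflow as tf

# Build a tiny Keras model in a fresh graph (assumption: TF 1.x graph mode).
graph = tf.Graph()
with graph.as_default():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax, name="output_node"),
    ])

# Collect the ops that own the weights and report their op type.
variable_ops = [(op.name, op.type) for op in graph.get_operations()
                if op.type in ("VarHandleOp", "VariableV2")]
for name, op_type in variable_ops:
    print(name, op_type)
```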
Here is my package list:
(env) C:\env>conda list
# packages in environment at C:\Users\bsmit\AppData\Local\Continuum\anaconda3\envs\env:
#
# Name Version Build Channel
_tflow_select 2.3.0 mkl
absl-py 0.7.0 py36_0
astor 0.7.1 py36_0
backcall 0.1.0 py36_0
blas 1.0 mkl
bleach 3.1.0 py36_0
ca-certificates 2019.1.23 0
certifi 2018.11.29 py36_0
colorama 0.4.1 py36_0
decorator 4.3.2 py36_0
entrypoints 0.3 py36_0
gast 0.2.2 py36_0
grpcio 1.16.1 py36h351948d_1
h5py 2.9.0 py36h5e291fa_0
hdf5 1.10.4 h7ebc959_0
icc_rt 2019.0.0 h0cc432a_1
icu 58.2 ha66f8fd_1
intel-openmp 2019.1 144
ipykernel 5.1.0 py36h39e3cac_0
ipython 7.3.0 py36h39e3cac_0
ipython_genutils 0.2.0 py36h3c5d0ee_0
ipywidgets 7.4.2 py36_0
jedi 0.13.3 py36_0
jinja2 2.10 py36_0
jpeg 9b hb83a4c4_2
jsonschema 2.6.0 py36h7636477_0
jupyter 1.0.0 py36_7
jupyter_client 5.2.4 py36_0
jupyter_console 6.0.0 py36_0
jupyter_core 4.4.0 py36_0
keras-applications 1.0.6 py36_0
keras-base 2.2.4 py36_0
keras-preprocessing 1.0.5 py36_0
libmklml 2019.0.3 0
libpng 1.6.36 h2a8f88b_0
libprotobuf 3.6.1 h7bd577a_0
libsodium 1.0.16 h9d3ae62_0
m2w64-gcc-libgfortran 5.3.0 6
m2w64-gcc-libs 5.3.0 7
m2w64-gcc-libs-core 5.3.0 7
m2w64-gmp 6.1.0 2
m2w64-libwinpthread-git 5.0.0.4634.697f757 2
markdown 3.0.1 py36_0
markupsafe 1.1.1 py36he774522_0
mistune 0.8.4 py36he774522_0
mkl 2019.1 144
mkl_fft 1.0.10 py36h14836fe_0
mkl_random 1.0.2 py36h343c172_0
msys2-conda-epoch 20160418 1
nbconvert 5.3.1 py36_0
nbformat 4.4.0 py36h3a5bc1b_0
notebook 5.7.4 py36_0
numpy 1.16.2 py36h19fb1c0_0
numpy-base 1.16.2 py36hc3f5095_0
openssl 1.1.1b he774522_0
pandoc 2.2.3.2 0
pandocfilters 1.4.2 py36_1
parso 0.3.4 py36_0
pickleshare 0.7.5 py36_0
pip 19.0.3 py36_0
prometheus_client 0.6.0 py36_0
prompt_toolkit 2.0.9 py36_0
protobuf 3.6.1 py36h33f27b4_0
pygments 2.3.1 py36_0
pyqt 5.9.2 py36h6538335_2
pyreadline 2.1 py36_1
python 3.6.8 h9f7ef89_7
python-dateutil 2.8.0 py36_0
pywinpty 0.5.5 py36_1000
pyyaml 3.13 py36hfa6e2cd_0
pyzmq 18.0.0 py36ha925a31_0
qt 5.9.7 vc14h73c81de_0
qtconsole 4.4.3 py36_0
scipy 1.2.1 py36h29ff71c_0
send2trash 1.5.0 py36_0
setuptools 40.8.0 py36_0
sip 4.19.8 py36h6538335_0
six 1.12.0 py36_0
sqlite 3.26.0 he774522_0
tensorboard 1.12.2 py36h33f27b4_0
tensorflow 1.12.0 mkl_py36h4f00353_0
tensorflow-base 1.12.0 mkl_py36h81393da_0
termcolor 1.1.0 py36_1
terminado 0.8.1 py36_1
testpath 0.4.2 py36_0
tornado 5.1.1 py36hfa6e2cd_0
traitlets 4.3.2 py36h096827d_0
vc 14.1 h0510ff6_4
vs2015_runtime 14.15.26706 h3a45250_0
wcwidth 0.1.7 py36h3d5aa90_0
webencodings 0.5.1 py36_1
werkzeug 0.14.1 py36_0
wheel 0.33.1 py36_0
widgetsnbextension 3.4.2 py36_0
wincertstore 0.2 py36h7fe50ca_0
winpty 0.4.3 4
yaml 0.1.7 hc54c509_2
zeromq 4.3.1 h33f27b4_3
zlib 1.2.11 h62dcd97_3
Answer 0 (score: 0)
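Your model doesn't have any input. A Keras model needs some form of input, either as an input layer or as an argument to the first layer. In the code above, input_shape is passed to the Dense layer, but a Sequential model only uses it on the first layer, which here is Flatten.
Check out this link for a clear example of using Keras on MNIST. It's almost exactly the same code, but note the following line:
model.add(Dense(512, activation='relu', input_shape=(784,)))
An input_shape on the first layer is what you're missing.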
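A minimal sketch of the question's model with the input declared on the first layer. Since Flatten comes first here, it takes the 28x28 shape of the raw MNIST images that load_data returns, and the Dense layer no longer needs input_shape. With the input declared, the model is built immediately, before fit is called.

```python
import tensorflow as tf

model = tf.keras.Sequential()
# The first layer declares the input: raw MNIST images are 28x28.
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax, name="output_node"))

print(model.input_shape)   # (None, 28, 28)
print(model.output_shape)  # (None, 10)
```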
Answer 1 (score: 0)
In freeze_graph, change the input argument from input_pb to input_meta_graph. That worked for me.