我正在尝试对图像分类器进行TF实现(使用py3.5和Windows 10,TF 0.12),所以我重新使用现有模型as described here,但没有所有奇怪的Bazel东西。修复this line上的py2-to-3错误(将keys()
包裹在list()
中)后,它在我的10个不同类别的文件夹上运行得很好。但是,缺乏表现;培训成功率约为83%,验证集最多不超过60%。所以我想从vgg16模型(这是我以前在Caffe / ubuntu中使用过的模型)进行一些转移学习; one I've found is here已准备好下载。
我现在的问题是,如何在Tensorflow中加载.tfmodel文件?该脚本期望下载tar.gz,足够公平。它显然包含一个名为classify_image_graph_def.pb
的文件,它不是.tfmodel文件。查看some example code我发现加载.tfmodel文件非常简单,因此我修改了create_inception_graph
函数以直接指向vgg16-20160129.tfmodel
文件。运行此操作后,我收到此错误:
File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 450, in import_graph_def
ret.append(name_to_op[operation_name].outputs[output_index])
KeyError: 'pool_3/_reshape'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "retrain.py", line 995, in <module>
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "retrain.py", line 713, in main
create_inception_graph())
File "retrain.py", line 235, in create_inception_graph
RESIZED_INPUT_TENSOR_NAME]))
File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 453, in import_graph_def
'Requested return_element %r not found in graph_def.' % name)
ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.
这是加载代码:
def create_inception_graph():
""""Creates a graph from saved GraphDef file and returns a Graph object.
Returns:
Graph holding the trained Inception network, and various tensors we'll be
manipulating.
"""
with tf.Session() as sess:
#model_filename = os.path.join(
# FLAGS.model_dir, 'classify_image_graph_def.pb')
model_filename = os.path.join(
FLAGS.model_dir, 'vgg16-20160129.tfmodel')
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
tf.import_graph_def(graph_def, name='', return_elements=[
BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
RESIZED_INPUT_TENSOR_NAME]))
return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
在tf.import_graph_def
调用中似乎出现了一些错误,但奇怪的是没有该函数的文档。我正在尝试甚至可能吗?有一大堆瓶颈张量和jpeg数据和调整大小的输入张量名称,我不知道它们的用途是什么,这个例子没有复制。
答案 0 :(得分:0)
就像它在跟踪中所说的那样,存在一个ValueError。
ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.
您的图表文件 - 'vgg16-20160129.tfmodel'中没有此节点。 重新检查变量BOTTLENECK_TENSOR_NAME,JPEG_DATA_TENSOR_NAME,RESIZED_INPUT_TENSOR_NAME。这些应该与您的网络架构相对应。