Tensorflow:从vgg16 .tfmodel文件转移学习

时间:2016-12-19 06:00:19

标签: python tensorflow

我正在尝试对图像分类器进行TF实现(使用py3.5和Windows 10,TF 0.12),所以我重新使用现有模型as described here,但没有所有奇怪的Bazel东西。修复this line上的py2-to-3错误(将keys()包裹在list()中)后,它在我的10个不同类别的文件夹上运行得很好。但是,缺乏表现;培训成功率约为83%,验证集最多不超过60%。所以我想从vgg16模型(这是我以前在Caffe / ubuntu中使用过的模型)进行一些转移学习; one I've found is here已准备好下载。

我现在的问题是,如何在Tensorflow中加载.tfmodel文件?该脚本期望下载tar.gz,足够公平。它显然包含一个名为classify_image_graph_def.pb的文件,它不是.tfmodel文件。查看some example code我发现加载.tfmodel文件非常简单,因此我修改了create_inception_graph函数以直接指向vgg16-20160129.tfmodel文件。运行此操作后,我收到此错误:

File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 450, in import_graph_def
    ret.append(name_to_op[operation_name].outputs[output_index])
KeyError: 'pool_3/_reshape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "retrain.py", line 995, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "retrain.py", line 713, in main
    create_inception_graph())
  File "retrain.py", line 235, in create_inception_graph
    RESIZED_INPUT_TENSOR_NAME]))
  File "C:\Users\User\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\importer.py", line 453, in import_graph_def
    'Requested return_element %r not found in graph_def.' % name)
ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

这是加载代码:

def create_inception_graph():
  """"Creates a graph from saved GraphDef file and returns a Graph object.
  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    #model_filename = os.path.join(
    #    FLAGS.model_dir, 'classify_image_graph_def.pb')
    model_filename = os.path.join(
        FLAGS.model_dir, 'vgg16-20160129.tfmodel')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor

tf.import_graph_def调用中似乎出现了一些错误,但奇怪的是没有该函数的文档。我正在尝试甚至可能吗?有一大堆瓶颈张量和jpeg数据和调整大小的输入张量名称,我不知道它们的用途是什么,这个例子没有复制。

1 个答案:

答案 0 :(得分:0)

就像它在跟踪中所说的那样,存在一个ValueError。

ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.

您的图表文件 - 'vgg16-20160129.tfmodel'中没有此节点。 重新检查变量BOTTLENECK_TENSOR_NAME,JPEG_DATA_TENSOR_NAME,RESIZED_INPUT_TENSOR_NAME。这些应该与您的网络架构相对应。