如何仅从.pb文件加载网络的前n层

时间:2018-10-04 13:53:00

标签: python tensorflow machine-learning

我想要的东西:一个protobuf文件,其中包含经过预训练的AlexNet的所有层,直到pool5层。

我所拥有的:我下载了AlexNet here的权重文件,然后 使用this代码将其转换为模型的protobuf文件和冻结的protobuf文件。 我用以下代码加载了生成的protobuf文件:

import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.contrib import graph_editor as editor

GRAPH_PB_PATH = 'alexnet.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   writer = tf.summary.FileWriter('logs', sess.graph)
   writer.close()
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

现在,我想扔掉pool5层之后的所有层,以使网络的输入是图像,而输出是pool5返回的值(即某些向量)。我想保存结果,现在将较小的网络再次保存到protobuf文件中。 那么,如何删除不必要的图层? 预先感谢!

1 个答案:

答案 0 :(得分:1)

graph_def = tf.GraphDef()
with open('alexnet.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    importer.import_graph_def(graph_def, name='')

new_model = tf.GraphDef()

with tf.Session(graph=graph) as sess:    
    for n in sess.graph_def.node:            
        nn = new_model.node.add()
        nn.CopyFrom(n)
        if n.op.name == 'pool5':
            break;

tf.train.write_graph(new_model, '.', 'cut_model.pb', as_text=False)