我如何使用tensorflow .pb文件?

时间:2018-07-12 07:35:23

标签: tensorflow

我有一个Tensorflow文件AlexNet.pb,我正在尝试加载它,然后对我拥有的图像进行分类。我已经搜索了几个小时,但是仍然找不到找到它然后分类图像的方法。它是如此明显并且我是如此愚蠢,因为似乎没有人提供加载和运行.pb文件的简单示例。

2 个答案:

答案 0 :(得分:2)

这取决于protobuf文件的创建方式。

如果.pb文件是以下结果:

    # Create a builder to export the model
    builder = tf.saved_model.builder.SavedModelBuilder("export")
    # Tag the model in order to be capable of restoring it specifying the tag set
    builder.add_meta_graph_and_variables(sess, ["tag"])
    builder.save()

您必须知道该模型的标记方式,并使用tf.saved_model.loader.load方法将保存的图形加载到当前的空白图形中。

如果模型已被冻结,则必须手动将二进制文件加载到内存中:

with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.get_default_graph()
tf.import_graph_def(graph_def, name="prefix")

在两种情况下,您都必须知道输入张量的名称和要执行的节点的名称:

例如,如果您的输入张量是名为batch_的占位符,而您要执行的节点是您必须命名为dense/BiasAdd:0的节点

    batch = graph.get_tensor_by_name('batch:0')
    prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

    values = sess.run(prediction, feed_dict={
        batch: your_input_batch,
    })

答案 1 :(得分:-1)

您可以使用opencv加载.pb模型, 例如。

net = cv2.dnn.readNet("model.pb")

确保您使用的是Opencv的特定版本-OpenCV 3.4.2或OpenCV 4