Tensorflow:如何从Tensorboard获得张量名称?

时间:2018-05-08 06:42:39

标签: python-3.x tensorflow

我从张量流检测模型动物园下载了ssd_mobilenet_v2_coco。我使用import_pb_to_tensorboard.py来显示Tensorboard上的结构。

我找到一个名为' image_tensor',this is the picture discribed in Tensorboard的节点。我想使用函数' get_tensor_by_name()'输入新图像并获得输出。然而,它失败了。

我试过了get_operation_by_name()' ,它也没有用。

以下是代码:

import tensorflow as tf

def one_image(im_path, model_path):
    sess= tf.Session()
    with sess.as_default():
        image_tensor = tf.image.decode_jpeg(tf.read_file(im_path), channels=3)
        saver = tf.train.import_meta_graph(model_path + "/model.ckpt.meta")
        saver.restore(sess, tf.train.latest_checkpoint(model_path))
        graph = tf.get_default_graph()

        # x = graph.get_tensor_by_name("import/image_tensor:0")
        # out_put = graph.get_tensor_by_name("import/detection_classes:0")

        x = graph.get_operation_by_name("import/image_tensor").outputs[0]
        outputs = graph.get_operation_by_name("import/detection_classes").outputs[0]
        out_put = sess.run(outputs, feed_dict={x: image_tensor.eval()})

        print(out_put)
        sess.close()

if __name__ == "__main__":
    one_image("testimg-4-resize.jpg", "ssd_mobilenet_v2_coco_2018_03_29")

这是KeyError:

KeyError: "The name 'import/image_tensor' refers to an Operation not in the graph."

我想知道如何从Tensorboard获取张量名称以及是否有另一种方法可以从' only-ckpts '中加载模型。

' 仅-ckpts '表示文件仅包含' model.ckpt.data-00000-of-00001 ' ,' model.ckpt.index '和' model.ckpt.meta '。

任何建议都将不胜感激。提前谢谢。

1 个答案:

答案 0 :(得分:1)

工具import_pb_to_tensorboard.py uses tf.import_graph_def to import the graph并使用默认name参数,即"import" as documented

您的代码通过tf.train.import_meta_graph导入图表,并使用默认的import_scope参数,该参数不会为导入的张量或操作名称添加前缀。很明显,您有两个选项可以纠正此错误:

  1. 执行以下操作代替import_meta_graph行:

    saver = tf.train.import_meta_graph(model_path + "/model.ckpt.meta",
                                       import_scope='import')
    
  2. 尝试按名称获取张量或操作时删除import/前缀:

    x = graph.get_tensor_by_name("image_tensor:0")
    out_put = graph.get_tensor_by_name("detection_classes:0")
    
    x = graph.get_operation_by_name("image_tensor").outputs[0]
    outputs = graph.get_operation_by_name("detection_classes").outputs[0]