How do I suppress output from the TensorFlow Object Detection API?

Asked: 2018-12-20 11:50:13

Tags: tensorflow, object-detection-api

I am using the TensorFlow Object Detection API. I trained a model and extracted a frozen graph, but during inference I get messages that look like a printout of the graph's layers and trained parameters. Here is a taste of it:

Optimized fused batch norm node name: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/BatchNorm/FusedBatchNorm"
op: "FusedBatchNorm"
input: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/Conv2D"
input: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/BatchNorm/gamma"
input: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/BatchNorm/beta"
input: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/BatchNorm/moving_mean"
input: "SecondStageFeatureExtractor/InceptionV2/Mixed_5c/Branch_3/Conv2d_0b_1x1/BatchNorm/moving_variance"
device: "/job:localhost/replica:0/task:0/device:GPU:0"
attr { key: "T" value { type: DT_FLOAT } }
attr { key: "data_format" value { s: "NHWC" } }

...and there is a lot more like it.

How can I hide these messages?

I have already tried os.environ['TF_CPP_MIN_LOG_LEVEL'] and tf.logging.set_verbosity(tf.logging.ERROR). Neither suppresses the Object Detection API messages.
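For reference, this is roughly how I am applying those two settings (the value '3' is only an example level that filters INFO, WARNING and ERROR from the C++ backend; as far as I know, the environment variable has to be set before tensorflow is imported to take effect):

import os
# Must be set before the tensorflow import; '3' is an example value
# that filters INFO, WARNING and ERROR output from the C++ side.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
# Silences Python-side tf.logging messages below ERROR.
tf.logging.set_verbosity(tf.logging.ERROR)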

My inference code is similar to the one shown here. This is my code:

import glob

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy import ndimage as nd
from scipy.misc import imresize

# Load the frozen graph exported by the Object Detection API.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')


images = glob.glob(TEST_IMAGE_PATHS)
fig, ax = plt.subplots(1)

with detection_graph.as_default():
    with tf.Session() as sess:
        # Get handles to input and output tensors
        ops = tf.get_default_graph().get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        tensor_dict = {}

        for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes']:
            tensor_name = key + ':0'

            if tensor_name in all_tensor_names:
                tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)

        image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

        # Run inference
        for image_filename in images:
            image = nd.imread(image_filename)
            image = imresize(image, 50)  # shrink to 50% of the original size

            output_dict = sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)})

            # All outputs are float32 numpy arrays; detection_boxes are
            # normalized [ymin, xmin, ymax, xmax], so scale them by the image size.
            boxes = output_dict['detection_boxes'][0]
            plt.imshow(image)
            shape = image.shape
            for i in range(int(output_dict['num_detections'][0])):
                box_y = boxes[i][0] * shape[0]
                box_x = boxes[i][1] * shape[1]
                box_h = (boxes[i][2] - boxes[i][0]) * shape[0]
                box_w = (boxes[i][3] - boxes[i][1]) * shape[1]

                ax.add_patch(patches.Rectangle((box_x, box_y), box_w, box_h, linewidth=1, edgecolor='r', facecolor='none'))
            fig.show()
            plt.waitforbuttonpress()
            ax.clear()

0 Answers