TensorFlow:预处理操作也会在图表中冻结吗?

时间:2017-05-09 08:30:17

标签: graph tensorflow deep-learning preprocessor

我相信在训练之后,保存到检查点的模型不包含任何预处理操作,因为在检查检查点模型时,可用的操作从模型的输入开始(而不是在模型的输入之前的预处理操作)模型输入)。

但是,当冻结从点文件恢复的图形时,图形还有其他预处理操作,预处理操作是否也会被冻结?我在图表中包含了一个预处理操作以测试时间,并打算将图表与检查点模型一起冻结,但结果似乎在这两种情况下变化很大:

  1. 将原始图像放入冻结图形中,并使用冻结图形中包含的预处理操作 - >非常非常差的准确性,好像没有进行预处理一样。

  2. 在将预处理的图像放入不包含任何预处理操作的冻结图之前,首先预处理图像 - >结果按预期工作,准确度很高。

  3. 所以我的问题是预处理操作是否被有效冻结,或者是否建议仅在测试时预处理图像,以便我们可以保留冻结图仅用于执行推理(而不是任何预处理操作)?我的目的是在图表中包含预处理操作以使其更方便,但似乎这种方法不起作用。

    TensorFlow对此类工作流程的看法是什么?预处理应该在图中完成并冻结,还是应该是冻结图之外的单独任务?

    以下是我打算将预处理操作放在图表中并将其全部冻结的方法:

    with tf.Graph().as_default() as graph:
    
        # image = tf.placeholder(shape=[None, None, 3], dtype=tf.float32, name = 'Placeholder_only')
        # preprocessed_image = inception_preprocessing.preprocess_for_eval(image, 299, 299)
        # preprocessed_image = tf.expand_dims(preprocessed_image, 0)
        img_array = tf.placeholder(dtype=tf.float32, shape=[None,None,3], name='Placeholder_only')
        preprocessed_image = inception_preprocessing.preprocess_for_eval(img_array, 299, 299)
        preprocessed_image = tf.expand_dims(preprocessed_image, 0, name='expand_preprocessed_img')
    
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(preprocessed_image, num_classes = 5, is_training = False)
    
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
    
        #Setup graph def
        input_graph_def = graph.as_graph_def()
        output_node_names = "InceptionResnetV2/Logits/Predictions"
        output_graph_name = "./frozen_flowers_model_IR2_with_preprocesssing.pb"
    
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_file)
    
            # count=0
            # for op in graph.get_operations():
            #     print (op.name)
            #     count+=1
            #     if count==50:
            #         assert False
    
            #Exporting the graph
            print ("Exporting graph...")
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(","))
    
            with tf.gfile.GFile(output_graph_name, "wb") as f:
                f.write(output_graph_def.SerializeToString())
    

0 个答案:

没有答案