Tensorflow:如何使用持久性会话评估测试数据?

时间:2018-05-05 05:02:54

标签: python tensorflow

最近,我遇到了张量流评估问题。我的要求是我想使用训练有素的网络来评估一些动态测试数据

动态测试数据意味着我最初只有一个测试数据,随后根据第一个测试数据的评估结果生成第二个测试数据。

一个简单的解决方案是我可以编写一个评估函数来逐一评估测试数据,例如

def evaluation(testing_data):
    input = tf.placeholder(tf.float32, [None, params['height'], params['width'], params['depth']])
    logits = inference(input)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir + '/'))
        prediction = self.session.run([logits['pred']], feed_dict={input: testing_data})
    return prediction

但是,这个解决方案非常慢,因为我需要在调用此函数时初始化Tensorflow图形(显然,图形只需要初始化一次,因为我不会更改任何参数或其他任何参数训练有素的网络)。

为了提高效率,我重写评估功能如下。

def new_evaluation(testing_data, first_initialize):
    input = tf.placeholder(tf.float32, [None, params['height'], params['width'], params['depth']])
    if first_initialize:
        logits = inference(input)
        saver = tf.train.Saver()
        # use a global variable to save session
        global_session = tf.Session()   
        saver.restore(global_session, tf.train.latest_checkpoint(checkpoint_dir + '/'))
        prediction = global_session.run([logits['pred']], feed_dict={input: testing_data})
    else:
        prediction = global_session.run([logits['pred']], feed_dict={input: testing_data})
    return prediction 
# I will manually close the global session when I finish all of jobs.

new_evaluation函数的意图很明显:

  1. 如果是第一次评估,我们会创建张量流推理图并恢复经过训练的参数。

  2. 如果我们已经初始化了会话,我们会直接使用它来评估新的测试数据。

  3. 问题:

    这个new_evaluation函数在进入 first_initialize 分支时工作正常,但是当我第二次调用它时会弹出一些错误。这是日志

    tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,256,256,3]
         [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[?,256,256,3], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
         [[Node: prediction/_431 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1275_prediction", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
    

    这个日志似乎告诉我们数据类型不一致。但是,第二测试数据与第一测试数据具有相同的类型和形状。

    为什么第一个有效但第二个有崩溃?

    任何想法或更好的解决方案?谢谢!

1 个答案:

答案 0 :(得分:0)

我已经想出如何有效地解决这个问题。我在这里提供我的解决方案以防其他人需要它。

我们可以在第一次初始化后使用协同程序来保存会话上下文。这是代码

def evaluation(self, checkpoint_dir):
    input = tf.placeholder(tf.float32, [None, self.params['height'], self.params['width'], self.params['depth']])
    logits = inference(input)
    saver = tf.train.Saver(max_to_keep=2)
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir + '/'))
        while not global_close:
            prediction = sess.run([self.pred], feed_dict={input: global_input})
            yield prediction
# we have two global variables. global_close control the exit and global_input save the input


def your_task_process():
    global_close = False
    evaluator = evaluation(checkpoint_dir)  # create a generator

    # do you stuff

    global_input = your_input1      # assign a testing data
    inference_value1 = next(evaluator)

    # do you stuff

    global_input = your_input2      # assign the second testing data
    inference_value2 = next(evaluator)

    ...

    # when finish, terminate the generator
    global_close = True
    next(evaluator)