tensorflow输入管道:多次读取样本

时间:2016-12-16 16:40:15

标签: tensorflow

我试图在我的模型中实现一个从TFRecords二进制文件读取的输入管道; 每个二进制文件包含一个示例(图像,标签,我需要的其他东西)

我有一个带文件路径列表的文本文件;然后:

  1. 我将文本文件作为列表读取,我将其提供给string_input_producer()以生成队列;
  2. 我将队列提供给TFRecordReader,它读取序列化示例并解码二进制数据
  3. 我使用shuffle_batch()将示例安排到批处理
  4. 我使用批次来评估我的模型
  5. 问题是,事实证明同一个例子可以被多次读取,而一些例子可能根本就没有被访问过; 我将步数设置为图像总数除以批量大小;所以我希望在最后一步结束时访问所有输入示例,但事实并非如此;相反,有些是不止一次访问过的,有些则从不(随机);这使我的测试评估完全不可行

    如果有人暗示我做错了什么,请告诉我

    我的模型测试代码的简化版本如下; 谢谢!

    def my_input(file_list, batch_size)
    
        filename = []
        f = open(file_list, 'r')
        for line in f:
            filename.append(params.TEST_RECORDS_DATA_DIR + line[:-1])
    
        filename_queue = tf.train.string_input_producer(filename)
    
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
    
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label_raw': tf.FixedLenFeature([], tf.string),
                'name': tf.FixedLenFeature([], tf.string)
                })
    
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3)
        image = tf.reshape(image, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3))
        image = tf.cast(image, tf.float32)/255.0
        image = preprocess(image)
    
        label = tf.decode_raw(features['label_raw'], tf.uint8)
        label.set_shape(params.NUM_CLASSES)
    
        name = features['name']
    
        images, labels, image_names = tf.train.batch([image, label, name],
                batch_size=batch_size, num_threads=2,
                capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
    
        return images, labels, image_names
    
    
    def main()
    
        with tf.Graph().as_default():
    
            # call input operations
            images, labels, image_names = my_input(file_list=params.TEST_FILE_LIST, batch_size=params.BATCH_SIZE)
    
            # load a trained model and make predictions     
            prediction = infer(images, labels, image_names)
    
            with tf.Session() as sess:
    
                for step in range(params.N_STEPS):
                    prediction_values = sess.run([prediction])
                    # process output
    
        return
    

1 个答案:

答案 0 :(得分:0)

我的猜测是tf.train.string_input_producer(filename)被设置为无限期地生成文件名,如果你在多个(2)线程中批处理示例,可能是一个线程已经开始处理文件的情况第二次,而另一个还没有设法完成第一轮。要准确读取每个示例,请使用:

tf.train.string_input_producer(filename, num_epochs=1)

并在会话开始时初始化局部变量:

sess.run(tf.initialize_local_variables())