read_cifar10()例程如何在TensorFlow教程中返回除第一个对象之外的任何内容?

时间:2016-01-27 12:39:51

标签: python tensorflow

TensorFlow有CIFAR-10教程,is discussed here。 Python is here中的源代码。

它有read_cifar10() routine here,用于从二进制文件中读取样本。

我无法理解,它是如何运作的。怀疑这与TensorFlow延迟性质有某种关系,但无法弄清楚如何。

在某些时候,例程执行以下操作:

$http

我在这里看到,从头开始创建一个新的阅读器,然后这个阅读器指向文件名队列。

# Read a record, getting filenames from the filename_queue. No # header or footer in the CIFAR-10 format, so we leave header_bytes # and footer_bytes at their default of 0. reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) result.key, value = reader.read(filename_queue) 来电返回了多少个样本?

稍后,在read方法内部,代码执行以下操作:

distorted_inputs()

这里print ('Filling queue with %d CIFAR images before starting to train. ' 'This will take a few minutes.' % min_queue_examples) # Generate a batch of images and labels by building up a queue of examples. return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples) 是普通的Python调用,不是延迟的,所以注释假定会立即提取20000条记录。

怎么会发生?我到处都只看到每个记录的逻辑。它如何在许多记录中成倍增加?

1 个答案:

答案 0 :(得分:2)

TLDR; reader.read仅向计算图添加read操作,实际执行发生在session.run期间,由while(True): session.run(...)start_queue_runners类型的循环中的单独线程完成}}

长版: 这是“输入管道”的一部分,由于读取/预取需要异步发生以避免阻塞这一事实使其变得复杂。官方如何描述输入管道是here

更具体地说,reader.read添加了一个操作来将单个记录读取到计算图。然后,此操作将输入_generate_image_and_label_batch内创建的shuffle_batch。到目前为止还没有阅读。 shuffle_batch操作创建了一个解耦输入流的队列,在某种意义上,可以使用不同的session.run调用异步完成队列之前和队列之后的部分评估,并提供队列在中间缓冲。此外,shuffle_batch操作将作为GraphKeys.QUEUE_RUNNERS集合的一部分提供给队列的操作注册。

train()内,操作tf.start_queue_runners将创建与GraphKeys.QUEUE_RUNNERS集合中注册的入队操作相对应的多个线程,并开始在循环中对它们进行评估。 reader.read的结果将流经其他操作,直到到达shuffle_batch队列并保存在其内存缓冲区中。

shuffle_batch之后的图形部分将由sess.run([train_op, loss])命令启动的主Python线程驱动。该线程将收集保存在shuffle_batch队列上的一批示例,并将其向前传播。

以下是手动输入输入队列而不是使用队列运行程序的示例。

queue_dtype = np.int32
queue_capacity = 2
values_queue = tf.FIFOQueue(capacity=queue_capacity, dtypes=queue_dtype)
size_op = values_queue.size()
value_placeholder = tf.placeholder(dtype=queue_dtype)
enqueue_op = values_queue.enqueue(value_placeholder)
dequeue_op = values_queue.dequeue()
close_op = values_queue.close()

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

# add two elements onto the queue
sess.run([enqueue_op], {value_placeholder:2})
sess.run([enqueue_op], {value_placeholder:3})
# if you uncomment the next line, you'll hang because queue is full
# sess.run([enqueue_op], {value_placeholder:4})

# close the queue. This means 3rd read will throw OutOfRangeError instead of
# hanging until queue is replenished
sess.run([close_op])
print('queue has %d/%d entries' % (sess.run([size_op])[0], queue_capacity))

# take two elements off the queue
fancy_computation = tf.square(dequeue_op)
print('Computation result %d' %(sess.run([fancy_computation])[0]))
print('queue has %d/%d entries' % (sess.run([size_op])[0], queue_capacity))
print('Computation result %d' %(sess.run([fancy_computation])[0]))
print('queue has %d/%d entries' % (sess.run([size_op])[0], queue_capacity))
print('Computation result %d' %(sess.run([fancy_computation])[0]))
print('queue has %d/%d entries' % (sess.run([size_op])[0], queue_capacity))

如果你运行它应该看到什么

queue has 2/2 entries
Computation result 4
queue has 1/2 entries
Computation result 9
queue has 0/2 entries
---------------------------------------------------------------------------
OutOfRangeError