Tensorflow - 来自tf.train.shuffle_batch的下一批数据

时间:2017-02-01 11:07:33

标签: python tensorflow

我有一个tfrecords文件,我希望从中创建批量数据。我正在使用tf.train.shuffle_batch()创建一个批处理。在我的训练中,我想打电话给批次并传递它们。这就是我被困的地方。我读到TFRecordReader()的poistion被保存在图的状态中,下一个例子是从后续位置读取的。问题是我无法确定如何加载下一批。我使用以下代码创建批次。

def read_and_decode_single_example(filename):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'context': tf.FixedLenFeature([160], tf.int64),
            'context_len': tf.FixedLenFeature([1], tf.int64),
            'utterance': tf.FixedLenFeature([160], tf.int64),
            'utterance_len': tf.FixedLenFeature([1], tf.int64),
            'label': tf.FixedLenFeature([1], tf.int64)
        })

    contexts = features['context']
    context_lens = features['context_len']
    utterances = features['utterance']
    utterance_lens = features['utterance_len']
    labels = features['label']

    return contexts, context_lens, utterances, utterance_lens, labels

contexts, context_lens, utterances, utterance_lens, labels = \
    read_and_decode_single_example('data/train.tfrecords')

contexts_batch, context_lens_batch, \
    utterances_batch, utterance_lens_batch, \
    labels_batch = tf.train.shuffle_batch([contexts, context_lens, utterances,
                                          utterance_lens, labels],
                                          batch_size=batch_size,
                                          capacity=3*batch_size,
                                          min_after_dequeue=batch_size)

这给了我一批数据。我想使用feed_dict范例来传递批次进行培训,其中每次迭代都会传入新批次。如何加载这些批次?调用read_and_decodetf.train.shuffle_batch会再次调用下一批吗?

1 个答案:

答案 0 :(得分:1)

read_and_decode_single_example()函数为网络创建(子)图,用于加载数据;你只打电话一次。它可能更恰当地称为build_read_and_decode_single_example_graph(),但这有点长。

“魔力”在于多次评估(即使用)_batch张量,例如

batch_size = 100
# ...

with tf.Session() as sess:
    # get the first batch of 100 values
    first_batch = sess.run([contexts_batch, context_lens_batch,
                            utterances_batch, utterance_lens_batch,
                            labels_batch])

    # second batch of different 100 values
    second_batch = sess.run([contexts_batch, context_lens_batch,
                            utterances_batch, utterance_lens_batch,
                            labels_batch])
    # etc.

当然,您不是手动从会话中获取这些值,而是将它们反馈到网络的其他部分。机制是相同的:每当直接或间接获取这些张量中的一个时,批处理机制将负责每次为您提供一个新的批次(具有不同的值)。