我有一个tfrecords
文件,我希望从中创建批量数据。我正在使用tf.train.shuffle_batch()
创建一个批处理。在我的训练中,我想打电话给批次并传递它们。这就是我被困的地方。我读到TFRecordReader()
的poistion被保存在图的状态中,下一个例子是从后续位置读取的。问题是我无法确定如何加载下一批。我使用以下代码创建批次。
def read_and_decode_single_example(filename):
filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'context': tf.FixedLenFeature([160], tf.int64),
'context_len': tf.FixedLenFeature([1], tf.int64),
'utterance': tf.FixedLenFeature([160], tf.int64),
'utterance_len': tf.FixedLenFeature([1], tf.int64),
'label': tf.FixedLenFeature([1], tf.int64)
})
contexts = features['context']
context_lens = features['context_len']
utterances = features['utterance']
utterance_lens = features['utterance_len']
labels = features['label']
return contexts, context_lens, utterances, utterance_lens, labels
contexts, context_lens, utterances, utterance_lens, labels = \
read_and_decode_single_example('data/train.tfrecords')
contexts_batch, context_lens_batch, \
utterances_batch, utterance_lens_batch, \
labels_batch = tf.train.shuffle_batch([contexts, context_lens, utterances,
utterance_lens, labels],
batch_size=batch_size,
capacity=3*batch_size,
min_after_dequeue=batch_size)
这给了我一批数据。我想使用feed_dict
范例来传递批次进行培训,其中每次迭代都会传入新批次。如何加载这些批次?调用read_and_decode
和tf.train.shuffle_batch
会再次调用下一批吗?
答案 0 :(得分:1)
read_and_decode_single_example()
函数为网络创建(子)图,用于加载数据;你只打电话一次。它可能更恰当地称为build_read_and_decode_single_example_graph()
,但这有点长。
“魔力”在于多次评估(即使用)_batch
张量,例如
batch_size = 100
# ...
with tf.Session() as sess:
# get the first batch of 100 values
first_batch = sess.run([contexts_batch, context_lens_batch,
utterances_batch, utterance_lens_batch,
labels_batch])
# second batch of different 100 values
second_batch = sess.run([contexts_batch, context_lens_batch,
utterances_batch, utterance_lens_batch,
labels_batch])
# etc.
当然,您不是手动从会话中获取这些值,而是将它们反馈到网络的其他部分。机制是相同的:每当直接或间接获取这些张量中的一个时,批处理机制将负责每次为您提供一个新的批次(具有不同的值)。