我有一些序列到序列场景的训练示例,它们以tf.train.SequenceExample
形式存储在一个(或多个)文件TFRecordWriter
中。我想阅读,解码它们并将混乱的批次送入我的网络。我一直在努力使用文档和一些在这里和那里找到的教程,但我无法用这些东西做任何事情。我正在研究一个自包含的例子,如下所示。
import random
import tensorflow as tf
from six.moves import xrange
MIN_LEN = 6
MAX_LEN = 12
NUM_EXAMPLES = 20
BATCH_SIZE = 3
PATH = 'ciaone.tfrecords'
MIN_AFTER_DEQUEUE = 10
NUM_THREADS = 2
SAFETY_MARGIN = 1
CAPACITY = MIN_AFTER_DEQUEUE + (NUM_THREADS + SAFETY_MARGIN) * BATCH_SIZE
def generate_example():
# fake examples which are just useful to have a quick visualization.
# The input is a sequence of random numbers.
# The output is a sequence made of those numbers from the
# input sequence which are greater or equal then the average.
length = random.randint(MIN_LEN, MAX_LEN)
input_ = [random.randint(0, 10) for _ in xrange(length)]
avg = sum([1.0 * item for item in input_]) / len(input_)
output = [item for item in input_ if item >= avg]
return input_, output
def encode(input_, output):
length = len(input_)
example = tf.train.SequenceExample(
context=tf.train.Features(
feature={
'length': tf.train.Feature(
int64_list=tf.train.Int64List(value=[length]))
}),
feature_lists=tf.train.FeatureLists(
feature_list={
'input': tf.train.FeatureList(
feature=[
tf.train.Feature(
int64_list=tf.train.Int64List(value=[item]))
for item in input_]),
'output': tf.train.FeatureList(
feature=[
tf.train.Feature(
int64_list=tf.train.Int64List(value=[item]))
for item in output])
}
)
)
return example
def decode(example):
context_features = {
'length': tf.FixedLenFeature([], tf.int64)
}
sequence_features = {
'input': tf.FixedLenSequenceFeature([], tf.int64),
'output': tf.FixedLenSequenceFeature([], tf.int64)
}
ctx, seq = tf.parse_single_sequence_example(
example, context_features, sequence_features)
input_ = seq['input']
output = seq['output']
return input_, output
if __name__ == '__main__':
# STEP 1. -- generate a dataset.
with tf.python_io.TFRecordWriter(PATH) as writer:
for _ in xrange(NUM_EXAMPLES):
record = encode(*generate_example())
writer.write(record.SerializeToString())
with tf.Session() as sess:
queue = tf.train.string_input_producer([PATH])
reader = tf.TFRecordReader()
_, value = reader.read(queue)
input_, output = decode(value)
# HERE I AM STUCK!
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
try:
while True:
# do something...
except tf.errors.OutOfRangeError, e:
coord.request_stop(e)
finally:
coord.request_stop()
coord.join(threads)
coord.request_stop()
coord.join(threads)
任何人都可以建议我如何进行? 提前谢谢!
P.S。作为一个副请求:任何有关资源的指针都可以更好地理解TensorFlow的输入管道API。
答案 0 :(得分:1)
如果您正在处理Example
而不是SequenceExample
,那么就像在解码的张量上添加对tf.train.shuffle_batch
的调用一样简单。
_, value = reader.read(queue)
input_, output = decode(value)
batch_input, batch_output = tf.train.shuffle_batch([input_, output],
batch_size=BATCH_SIZE, capacity=CAPACITY,
min_after_sequeue=MIN_AFTER_DEQUEUE)
然而,随机批量要求您传入的张量具有静态形状,这在此不正确。对于可变形状张量,您可以将tf.train.batch
与dynamic_pad=True
一起使用。这将为您处理批处理(和填充),但不会随机播放您的示例。不幸的是,shuffle_batch
没有采用dynamic_pad
参数。
有一个解决方法described here,您可以在调用RandomShuffleQueue
之前添加tf.train.batch
:
inputs = decode(value)
dtypes = list(map(lambda x: x.dtype, inputs))
shapes = list(map(lambda x: x.get_shape(), inputs))
queue = tf.RandomShuffleQueue(CAPACITY, MIN_AFTER_DEQUEUE, dtypes)
enqueue_op = queue.enqueue(inputs)
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS)
tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr)
inputs = queue.dequeue()
for tensor, shape in zip(inputs, shapes):
tensor.set_shape(shape)
# Now you can use tf.train.batch with dynamic_pad=True, and the order in which
# it enqueues elements will be permuted because of RandomShuffleQueue.
batch_input, batch_output = tf.train.batch(inputs, batch_size, capacity=capacity,
dynamic_pad=True, name=name)
此实施的模式示例here(在Google的Magenta项目中)。