Batching, padding, and shuffling tf.train.SequenceExample

Asked: 2017-04-12 10:59:51

Tags: input tensorflow protocol-buffers padding pipeline

I have training examples for a sequence-to-sequence scenario, stored as tf.train.SequenceExample protos in one (or more) TFRecord files written with a TFRecordWriter. I would like to read them, decode them, and feed shuffled batches to my network. I have been struggling with the documentation and a couple of tutorials found here and there, but I couldn't get anywhere with this stuff. I am working on a self-contained example, shown below.

import random

import tensorflow as tf

from six.moves import xrange


MIN_LEN = 6
MAX_LEN = 12
NUM_EXAMPLES = 20
BATCH_SIZE = 3
PATH = 'ciaone.tfrecords'
MIN_AFTER_DEQUEUE = 10
NUM_THREADS = 2
SAFETY_MARGIN = 1
CAPACITY = MIN_AFTER_DEQUEUE + (NUM_THREADS + SAFETY_MARGIN) * BATCH_SIZE


def generate_example():
    # Fake examples that are just useful for a quick visualization.
    # The input is a sequence of random numbers.
    # The output is the subsequence of input items that are
    # greater than or equal to the input average.
    length = random.randint(MIN_LEN, MAX_LEN)
    input_ = [random.randint(0, 10) for _ in xrange(length)]
    avg = sum([1.0 * item for item in input_]) / len(input_)
    output = [item for item in input_ if item >= avg]
    return input_, output
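For a quick sanity check of the filtering rule above, here is the same computation on a fixed input sequence instead of random draws (the concrete numbers are just an illustration, not part of the original question):

```python
# Reproduce generate_example's filtering rule on a fixed input sequence.
input_ = [3, 7, 1, 9, 4, 6]
avg = sum(1.0 * item for item in input_) / len(input_)  # 30 / 6 = 5.0
output = [item for item in input_ if item >= avg]
print(output)  # [7, 9, 6]
```

Note that input and output sequences generally have different lengths, which is exactly why padding comes up later.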


def encode(input_, output):
    length = len(input_)
    example = tf.train.SequenceExample(
        context=tf.train.Features(
            feature={
                'length': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[length]))
            }),
        feature_lists=tf.train.FeatureLists(
            feature_list={
                'input': tf.train.FeatureList(
                    feature=[
                        tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[item]))
                        for item in input_]),
                'output': tf.train.FeatureList(
                    feature=[
                        tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[item]))
                        for item in output])
            }
        )
    )
    return example


def decode(example):
    context_features = {
        'length': tf.FixedLenFeature([], tf.int64)
    }
    sequence_features = {
        'input': tf.FixedLenSequenceFeature([], tf.int64),
        'output': tf.FixedLenSequenceFeature([], tf.int64)
    }
    ctx, seq = tf.parse_single_sequence_example(
        example, context_features, sequence_features)
    input_ = seq['input']
    output = seq['output']
    return input_, output

if __name__ == '__main__':
    # STEP 1. -- generate a dataset.
    with tf.python_io.TFRecordWriter(PATH) as writer:
        for _ in xrange(NUM_EXAMPLES):
            record = encode(*generate_example())
            writer.write(record.SerializeToString())

    with tf.Session() as sess:
        queue = tf.train.string_input_producer([PATH])
        reader = tf.TFRecordReader()
        _, value = reader.read(queue)
        input_, output = decode(value)

        # HERE I AM STUCK!

        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while True:
                # do something...
                pass
        except tf.errors.OutOfRangeError as e:
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)

Can anyone suggest how to proceed? Thanks in advance!

P.S. As a side request: any pointers to resources for a better understanding of TensorFlow's input pipeline APIs would be appreciated.

1 Answer:

Answer 0 (score: 1)

If you were dealing with Example rather than SequenceExample, it would be as simple as adding a call to tf.train.shuffle_batch on your decoded tensors.

_, value = reader.read(queue)
input_, output = decode(value)
batch_input, batch_output = tf.train.shuffle_batch([input_, output],
    batch_size=BATCH_SIZE, capacity=CAPACITY,
    min_after_dequeue=MIN_AFTER_DEQUEUE)
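For intuition, the shuffling effect that min_after_dequeue provides can be sketched in plain Python. This is only an illustration of the idea (a buffer from which dequeues pick a random element), not TensorFlow's actual implementation:

```python
import random

def shuffle_dequeue(stream, min_after_dequeue, rng=random.Random(0)):
    """Sketch of a shuffling queue: buffer incoming elements and, once
    more than `min_after_dequeue` are buffered, emit a randomly chosen
    one. The output is a locally shuffled permutation of the input."""
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) > min_after_dequeue:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain the buffer once the input stream is exhausted
        yield buffer.pop(rng.randrange(len(buffer)))

shuffled = list(shuffle_dequeue(range(20), min_after_dequeue=5))
print(sorted(shuffled))  # the same 20 elements, in a permuted order
```

A larger min_after_dequeue gives better mixing at the cost of memory, which is why CAPACITY in the question is derived from it.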

However, shuffle_batch requires the tensors you pass in to have static shapes, which is not the case here. For variable-shape tensors you can use tf.train.batch with dynamic_pad=True. That takes care of batching (and padding) for you, but it does not shuffle your examples. Unfortunately, shuffle_batch does not take a dynamic_pad parameter.
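To make concrete what dynamic_pad=True does: within each batch, shorter sequences are padded (with zeros, by default) up to the length of the longest sequence in that batch. A plain-Python sketch of the same idea, not TensorFlow code:

```python
def pad_batch(sequences, pad_value=0):
    """Pad every sequence in the batch to the length of the longest one,
    mimicking the effect of tf.train.batch(..., dynamic_pad=True)."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]

batch = [[3, 7, 1], [9, 4], [6]]
print(pad_batch(batch))  # [[3, 7, 1], [9, 4, 0], [6, 0, 0]]
```

This is also why the 'length' context feature is stored in the SequenceExample: it lets downstream code (e.g. a dynamic RNN) distinguish real timesteps from padding.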

There is a workaround described here, in which you prepend a RandomShuffleQueue before the call to tf.train.batch:

inputs = decode(value)
dtypes = list(map(lambda x: x.dtype, inputs))
shapes = list(map(lambda x: x.get_shape(), inputs))
queue = tf.RandomShuffleQueue(CAPACITY, MIN_AFTER_DEQUEUE, dtypes)
enqueue_op = queue.enqueue(inputs)
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS)
tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr)
inputs = queue.dequeue()
for tensor, shape in zip(inputs, shapes):
    tensor.set_shape(shape)

# Now you can use tf.train.batch with dynamic_pad=True, and the order in which
# it enqueues elements will be permuted because of RandomShuffleQueue.
batch_input, batch_output = tf.train.batch(inputs, BATCH_SIZE, capacity=CAPACITY,
                                           dynamic_pad=True)

An example of this pattern is implemented here (in Google's Magenta project).