在TensorFlow中读取`TFRecord`文件时,`tf.train.shuffle_batch`崩溃

时间:2017-06-08 13:27:55

标签: python tensorflow

我正在尝试使用tf.train.shuffle_batch使用TensorFlow 1.0从TFRecord文件中使用批量数据。相关职能是:

def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)

    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    ## if no reshaping: `ValueError: All shapes must be fully defined` in
    ## `tf.train.shuffle_batch`
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    ## if using `strided_slice`, always get the first record
    # targets = tf.cast(
    #     tf.strided_slice(targets, [0], [1]),
    #     tf.int32
    # )
    ## error on shapes being fully defined
    # targets = tf.reshape(targets, [])
    ## get us: Invalid argument: Shape mismatch in tuple component 0.
    ## Expected [1], got [1000]
    targets.set_shape([1])
    return targets


def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size)
    )
    targets_batch = tf.one_hot(
        indices=targets_batch, depth=10, on_value=1, off_value=0
    )
    return targets_batch


def examine_batches(targets_batch):
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(10):
            targets = sess.run([targets_batch])
            print(targets)
        coord.request_stop()
        coord.join(threads)

代码通过examine_batches()输入,已经传递batch_generator()的输出。我相信batch_generator()会调用tfrecord_to_graph_ops()问题就在那个函数中。

我正在打电话

targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)

在1000字节的文件(数字0-9)上。如果我在会话中调用eval(),它会向我显示所有1,000个元素。但是,如果我尝试将其放入批处理生成器中,它就会崩溃。

如果我不重新塑造targets,则在调用ValueError: All shapes must be fully defined时会出现tf.train.shuffle_batch之类的错误。如果我拨打targets.set_shape([1]),让人想起Google的CIFAR-10 example code,我会在Invalid argument: Shape mismatch in tuple component 0. Expected [1], got [1000]中收到tf.train.shuffle_batch之类的错误消息。我还尝试使用tf.strided_slice来剪切一大块原始数据 - 这不会崩溃,但会导致一次又一次地获取第一个事件。

这样做的正确方法是什么?从TFRecord文件中提取批次?

注意,我可以手动编写一个切断原始字节数据并执行某种批处理的函数 - 如果我使用feed_dict方法将数据添加到图表中,则会特别容易 - 但我正在尝试学习如何使用TensorFlow的TFRecord文件以及如何使用它们内置的批处理函数。

谢谢!

1 个答案:

答案 0 :(得分:1)

Allen Lavoie在评论中指出了正确的解决方案。重要的缺失部分是enqueue_many=True作为tf.train.shuffle_batch()的论据。编写这些函数的正确方法是:

def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)

    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.reshape(targets, [-1])
    return targets

def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size),
        enqueue_many=True
    )
    return targets_batch