TensorFlow dataset: shuffle examples instead of batches

Date: 2018-03-15 23:12:12

Tags: tensorflow

How do I get a batched TensorFlow dataset to shuffle all of the examples? It is only shuffling the batches.

Below is a program that builds a dataset of 1000 items and runs through it for 10 epochs in batches of 5, with shuffle() turned on. I can see that TensorFlow splits the dataset into 200 batches of 5 examples each, and that the shuffling happens across those batches. I would like each new batch to be a random sample of the original 1000 examples, not one of the 200 original batches.

That is, this program:

import numpy as np
import tensorflow as tf
import random


def rec2tfrec_example(rec):
    def _int64_feat(value):
        arr_value = np.empty([1], dtype=np.int64)
        arr_value[0] = value
        return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_value))

    feat = {
        'uid': _int64_feat(rec['uid']),
    }

    return tf.train.Example(features=tf.train.Features(feature=feat)).SerializeToString()


def parse_example(tfrec_serialized_string):
    feat = {
        'uid': tf.FixedLenFeature([], tf.int64),
    }
    return tf.parse_example(tfrec_serialized_string, feat)


def write_tfrecs_to_file(fname, recs):
    with tf.python_io.TFRecordWriter(fname) as recwriter:
        for rec in recs:
            recwriter.write(rec)


def check_shuffle(sess, tfrec_output_filename, data, N, batch_size):
    epochs = 10
    dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
                     .batch(batch_size) \
                     .repeat(epochs) \
                     .shuffle(2*N) \
                     .map(parse_example, num_parallel_calls=2)
    tf_iter = dataset.make_initializable_iterator()
    get_next = tf_iter.get_next()

    sess.run(tf_iter.initializer)
    num_batches = N // batch_size
    for epoch in range(epochs):
        for batch in range(num_batches):
            tfres = sess.run(get_next)
            print("epoch=%4d batch=%d uid=%s" % (epoch, batch, tfres['uid']))


def main(N=1000, batch_size=5, tfrec_output_filename='tfrec_testing.tfrecords'):
    tf.reset_default_graph()
    data = [{'uid': uid } for uid in range(N)]
    tfrec_strings = [rec2tfrec_example(rec) for rec in data]
    write_tfrecs_to_file(tfrec_output_filename, tfrec_strings)
    with tf.Session() as sess:
        check_shuffle(sess, tfrec_output_filename, data, N, batch_size)

if __name__ == '__main__':
    main()

produces output like this:

epoch=   9 batch=186 uid=[685 686 687 688 689]
epoch=   9 batch=187 uid=[235 236 237 238 239]
epoch=   9 batch=188 uid=[520 521 522 523 524]
epoch=   9 batch=189 uid=[135 136 137 138 139]
epoch=   9 batch=190 uid=[95 96 97 98 99]
epoch=   9 batch=191 uid=[290 291 292 293 294]
epoch=   9 batch=192 uid=[230 231 232 233 234]
epoch=   9 batch=193 uid=[215 216 217 218 219]
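To see why the output looks this way, here is a minimal pure-Python sketch (no TensorFlow; the `batch` helper below is a hypothetical stand-in for `dataset.batch()`): once elements are grouped into batches, a `shuffle` applied afterwards treats each batch as a single element, so it can only reorder whole batches.

```python
import random

# Stand-in for dataset.batch(): group consecutive items into fixed-size lists.
def batch(items, batch_size):
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

data = list(range(1000))
batches = batch(data, 5)   # 200 batches of 5 consecutive uids
random.shuffle(batches)    # shuffle AFTER batching: only reorders whole batches

# Every batch is still a contiguous run of the original data, just like
# uid=[685 686 687 688 689] in the output above.
for b in batches:
    assert b == list(range(b[0], b[0] + 5))
```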

1 Answer:

Answer 0 (score: 1)

Ah, the order of batching and shuffling matters. If I build the dataset like this:
dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
                 .shuffle(2*N) \
                 .batch(batch_size) \
                 .repeat(epochs) \
                 .map(parse_example, num_parallel_calls=2)

it shuffles before batching, and everything works as expected.
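The same pure-Python sketch (again with a hypothetical `batch` helper standing in for `dataset.batch()`) shows why shuffling first fixes it: individual examples are permuted before grouping, so each batch draws from across the whole dataset.

```python
import random

# Stand-in for dataset.batch(): group consecutive items into fixed-size lists.
def batch(items, batch_size):
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

data = list(range(1000))
random.shuffle(data)       # shuffle BEFORE batching: permutes individual examples
batches = batch(data, 5)

# Batches are now random mixes of uids, but together they still cover
# every example exactly once per epoch.
assert sorted(uid for b in batches for uid in b) == list(range(1000))
```

Note that the answer's pipeline passes a buffer size of `2*N` to `shuffle`, i.e. larger than the dataset, so the shuffle amounts to a full random permutation; with a smaller buffer, `tf.data` only shuffles within a sliding window of that many elements.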