How do I get a TensorFlow Dataset to shuffle across all samples when batching? It is only shuffling the batches.
Below is a program that creates a dataset of 1000 items and runs through 10 epochs with a batch size of 5. I have shuffle() turned on.
I can see that TensorFlow splits the dataset into 200 batches of 5 examples each, and the shuffling happens across those batches. I want each new batch to be a random sample of the original 1000 samples, not one of the 200 original batches.
That is, this program:
import numpy as np
import tensorflow as tf
import random

def rec2tfrec_example(rec):
    def _int64_feat(value):
        arr_value = np.empty([1], dtype=np.int64)
        arr_value[0] = value
        return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_value))
    feat = {
        'uid': _int64_feat(rec['uid']),
    }
    return tf.train.Example(features=tf.train.Features(feature=feat)).SerializeToString()

def parse_example(tfrec_serialized_string):
    feat = {
        'uid': tf.FixedLenFeature([], tf.int64),
    }
    return tf.parse_example(tfrec_serialized_string, feat)

def write_tfrecs_to_file(fname, recs):
    recwriter = tf.python_io.TFRecordWriter(fname)
    for rec in recs:
        recwriter.write(bytes(rec))
    recwriter.close()

def check_shuffle(sess, tfrec_output_filename, data, N, batch_size):
    epochs = 10
    dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
        .batch(batch_size) \
        .repeat(epochs) \
        .shuffle(2*N) \
        .map(parse_example, num_parallel_calls=2)
    tf_iter = dataset.make_initializable_iterator()
    get_next = tf_iter.get_next()
    sess.run(tf_iter.initializer)
    num_batches = N//batch_size
    for epoch in range(epochs):
        for batch in range(num_batches):
            tfres = sess.run(get_next)
            print("epoch=%4d batch=%d uid=%s" % (epoch, batch, tfres['uid']))

def main(N=1000, batch_size=5, tfrec_output_filename='tfrec_testing.tfrecords'):
    tf.reset_default_graph()
    data = [{'uid': uid} for uid in range(N)]
    tfrec_strings = [rec2tfrec_example(rec) for rec in data]
    write_tfrecs_to_file(tfrec_output_filename, tfrec_strings)
    with tf.Session() as sess:
        check_shuffle(sess, tfrec_output_filename, data, N, batch_size)

if __name__ == '__main__':
    main()
produces output like this:
epoch= 9 batch=186 uid=[685 686 687 688 689]
epoch= 9 batch=187 uid=[235 236 237 238 239]
epoch= 9 batch=188 uid=[520 521 522 523 524]
epoch= 9 batch=189 uid=[135 136 137 138 139]
epoch= 9 batch=190 uid=[95 96 97 98 99]
epoch= 9 batch=191 uid=[290 291 292 293 294]
epoch= 9 batch=192 uid=[230 231 232 233 234]
epoch= 9 batch=193 uid=[215 216 217 218 219]
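The block-wise pattern above can be reproduced without TensorFlow. A pure-Python sketch of what batch-then-shuffle does conceptually (an illustration of the element-vs-batch distinction, not of tf.data internals):

```python
import random

# Sketch: batch first, then shuffle. The 1000 uids are cut into 200
# contiguous batches of 5, and only the *order of the batches* is
# randomized -- every batch remains a run of consecutive uids.
N, batch_size = 1000, 5
uids = list(range(N))

batches = [uids[i:i + batch_size] for i in range(0, N, batch_size)]
random.shuffle(batches)  # shuffling operates on whole batches

# Each emitted batch is still a consecutive block, matching the output above.
assert all(b == list(range(b[0], b[0] + batch_size)) for b in batches)
```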
Answer 0 (score: 1)
Ah, the order of batching and shuffling matters. If I set up the dataset like this:

dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
    .shuffle(2*N) \
    .batch(batch_size) \
    .repeat(epochs) \
    .map(parse_example, num_parallel_calls=2)

shuffling before batching, then it works.
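For contrast, a pure-Python sketch of shuffle-then-batch (again just an illustration, not the tf.data implementation):

```python
import random

# Sketch: shuffle first, then batch. Elements are permuted individually,
# so each batch of 5 draws from anywhere in the original 1000 samples.
N, batch_size = 1000, 5
uids = list(range(N))
random.shuffle(uids)  # element-level shuffle

batches = [uids[i:i + batch_size] for i in range(0, N, batch_size)]

# Almost every batch now mixes non-consecutive uids.
mixed = sum(1 for b in batches
            if sorted(b) != list(range(min(b), min(b) + batch_size)))
print("non-consecutive batches: %d of %d" % (mixed, len(batches)))
```

Note that shuffle(2*N) here buffers more elements than the whole dataset, so the shuffle is a full permutation; with a buffer smaller than the dataset, tf.data's shuffle is only approximately uniform.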