How do I shuffle inside itertools.cycle()?

Asked: 2018-01-16 16:01:50

Tags: python tensorflow itertools

I am using itertools in a data generator for training a TensorFlow-based network. The basic structure is as follows:

import itertools
import numpy as np

def data_generator(filenames, batch_size):
    files = itertools.cycle(filenames)
    while True:
        X = []
        Y = []
        for _ in range(batch_size):
            # the built-in next() works on both Python 2.6+ and Python 3
            filename = next(files)
            # read data into X and Y
            ....

        yield np.array(X), np.array(Y)

The data generator is used like this:

train_input = data_generator(train_filenames, batch_size=1)
for ep in range(num_epochs):
    for _ in range(num_train_samples):
        image_batch, label_batch = next(train_input)
        loss_val = sess.run([loss_op], feed_dict={})

My question is: in general, the training data should be shuffled after every epoch. How can I do that in this setup? Thanks.

1 Answer:

Answer 0 (score: 2)

A cycle iterator cannot be modified after it has been created, so you have to create a new iterator in each pass of the "epoch" loop:

import itertools
import random

def data_generator(filenames, batch_size):
    filenames = filenames[:]  # make a copy so the caller's list is not reordered
    random.shuffle(filenames)  # shuffles the copy in place
    files = itertools.cycle(filenames)
    ...

for ep in range(num_epochs):
    train_input = data_generator(train_filenames, batch_size=1)
    ...
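
If recreating the generator every epoch is inconvenient, one possible alternative is to replace itertools.cycle with a small generator that reshuffles its own copy of the list before every full pass. This is only a sketch of that idea: the names shuffled_cycle and batch_generator and the seed parameter are hypothetical, and the actual file reading is left out, so each batch is just a list of filenames.

```python
import random


def shuffled_cycle(items, seed=None):
    """Yield items forever, reshuffling a private copy before each full pass."""
    rng = random.Random(seed)  # hypothetical seed parameter, for reproducibility
    pool = list(items)         # copy, so the caller's list is never reordered
    if not pool:
        return                 # nothing to cycle over; avoid spinning forever
    while True:
        rng.shuffle(pool)      # new order at the start of every pass
        for item in pool:
            yield item


def batch_generator(filenames, batch_size, seed=None):
    """Yield lists of batch_size filenames drawn from shuffled_cycle."""
    files = shuffled_cycle(filenames, seed)
    while True:
        # Reading the files into X and Y would happen here in real code.
        yield [next(files) for _ in range(batch_size)]
```

With this approach the generator is created once, outside the epoch loop. Note that when batch_size does not divide the number of files evenly, a batch can contain the tail of one pass and the head of the next.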