I am using itertools in a data generator for training a TensorFlow-based network. The main framework is as follows:
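import itertools
import numpy as np

def data_generator(filenames, batch_size):
    # cycle() repeats the file list endlessly, always in the same order
    files = itertools.cycle(filenames)
    while True:
        X = []
        Y = []
        for _ in range(batch_size):
            filename = next(files)  # built-in next() works in Python 2 and 3
            # read data into X and Y
            ....
        yield np.array(X), np.array(Y)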
The data generator is used like this:
train_input = data_generator(train_filenames, batch_size=1)
for ep in range(num_epochs):
    for _ in range(num_train_samples):
        image_batch, label_batch = next(train_input)
        loss_val = sess.run([loss_op], feed_dict={})
My question is: in general, we need to shuffle the training data after each epoch. How can I do that in this setup? Thanks.
Answer 0 (score: 2)
A cycle iterator cannot be modified after it has been created, so you have to create a new iterator on each pass of the "epoch" loop:
import random

def data_generator(filenames, batch_size):
    filenames = filenames[:]  # make a copy so the caller's list keeps its order
    random.shuffle(filenames)
    files = itertools.cycle(filenames)
    ...

for ep in range(num_epochs):
    # a new generator per epoch means a fresh shuffle per epoch
    train_input = data_generator(train_filenames, batch_size=1)
    ...
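
As an alternative to recreating the generator every epoch, the reshuffle can be moved inside the generator itself. The following is only a minimal sketch of that idea, not part of the answer above: the name shuffled_epochs, the seed parameter, and the placeholder file names are all hypothetical, and it yields batches of file names rather than loaded arrays, since the loading step is elided in the question.

```python
import random


def shuffled_epochs(filenames, batch_size, seed=None):
    """Yield batches of file names forever, reshuffling before every pass.

    Hypothetical helper for illustration; loading the data is left out.
    """
    rng = random.Random(seed)  # private RNG, global random state is untouched
    order = list(filenames)    # copy, so the caller's list keeps its order
    if not order:
        return                 # nothing to iterate over
    while True:
        rng.shuffle(order)     # new order at the start of each pass (epoch)
        for start in range(0, len(order), batch_size):
            # the last batch of a pass is shorter if the sizes do not divide
            yield order[start:start + batch_size]


names = ["a.png", "b.png", "c.png", "d.png"]  # placeholder file names
gen = shuffled_epochs(names, batch_size=2, seed=0)
epoch_1 = next(gen) + next(gen)  # two batches make up one full pass
epoch_2 = next(gen) + next(gen)
```

Each pass visits every file exactly once, which itertools.cycle over a fixed list also does, but here the order is redrawn between passes, so the training loop can keep calling next() on a single generator.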