遍历喀拉斯邦的所有批次的生成器

时间:2019-11-14 06:13:00

标签: tensorflow keras conv-neural-network

我有一个由图像和标签组成的数据集,并加载了生成器,例如:

generator = image_generator.flow_from_directory(batch_size=BATCH_SIZE,
                                                           directory=val_dir,
                                                           shuffle=False,
                                                           target_size=(100,100),
                                                           class_mode='categorical')

使用CNN进行预测后,我想遍历所有结果并打印原始图像和预测的标签。

使用:

x,y = generator.next()

我设法做到了,但是我限于一批生成器中的元素数量。尝试打印更多的循环时会失去索引。

如何使用此方法遍历批次以获取所有结果?

2 个答案:

答案 0 :(得分:0)

您可以使用itertools.tee,它为您提供 n 个生成器的独立副本。 例如,您的情况

generator,gen_copy = itertools.tee(image_generator.flow_from_directory(
                                                  batch_size=BATCH_SIZE,
                                                  directory=val_dir,
                                                  shuffle=False,
                                                  target_size=(100,100),
                                                  class_mode='categorical',seed=0), 
                                                  n=2)

在这里,您将获得同一发生器的两个副本。一个可以传递给 fit_generator ,另一个可以根据您的目的进行迭代。请注意,在上面的代码块中添加了 seed = 0 ,以确保每个副本的混洗和转换顺序相同(如果您使用它们)。

现在 gen_copy 您可以像这样使用:

for x,y in gen_copy:
    print(x,y)

答案 1 :(得分:0)

official link中所述:

for e in range(epochs):
print('Epoch', e)
batches = 0
for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
    model.fit(x_batch, y_batch)
    batches += 1
    if batches >= len(x_train) / 32:
        # we need to break the loop by hand because
        # the generator loops indefinitely
        break

您可以模仿此示例代码来获取 datagen 中的每个批次。一个导入注意事项是,如果需要,您应该设置相同的种子以保持批次的顺序。