生成器函数总是两次产生相同的结果

时间:2019-07-25 21:23:37

标签: python generator yield

我有这个生成器函数,希望通过它筛选图像数据集。图像位于批处理大小为16的PyTorch数据加载器中。我在数据加载器上循环以获取一个批处理(16张图像),然后在该批次上循环以获取图像。

我想要做的是在图像上存储图像标签,然后yield,同时在图形上绘制16张图像。所以我想做images = next(show_batch(dataloader, labels_dataframe, nrows, ncols)),每次得到存储在图像中的16个图像标签和16张图像的图。这样,我可以识别出不良图像,并准备好将其标签从数据集中丢弃。该代码将两次生成相同(第一张)的16张图像。我怀疑这与每次创建新列表有关,所以我要重新启动生成器吗?

为什么代码会连续两次重复生成相同的16张图像,并且如何解决将其从labels_dataframe获取的标签存储在图像中时一次生成16张图像的问题?

def show_batch(dataloader, labels_dataframe, nrows, ncols):
    fig = plt.figure(figsize=(30,15))
    for i, batch in enumerate(train_dl):
        images = []
        for j, image in enumerate(batch['image']):
            ax = fig.add_subplot(nrows, ncols, j+1)
            ax.imshow(image.permute(2, 1, 0))
            images.append(labels_dataframe.loc[i*16+j, 'id_code'])
        yield images

1 个答案:

答案 0 :(得分:0)

Marco Bonelli指出,我应该说明如何调用生成器。在执行此操作时,我发现自己做错了什么,并修复了如何调用它和该函数。

我正在调用next(show_batch(dataloader, labels_dataframe, nrows, ncols)),因此每次都调用一个新的生成器函数。我没有生成生成器对象。

然后,当我创建一个生成器对象并开始调用它时,它仅显示前16张图像,然后仅生成标签,因此我将Figure对象移动到每个批处理循环中。修改后的代码及其命名方式:

def show_batch(dataloader, labels_dataframe, nrows, ncols):
    for i, batch in enumerate(train_dl):
        fig = plt.figure(figsize=(30,15))
        images = []
        for j, image in enumerate(batch['image']):
            ax = fig.add_subplot(nrows, ncols, j+1)
            ax.imshow(image.permute(2, 1, 0))
            images.append(labels_dataframe.loc[i*16+j, 'id_code'])
        yield images

sample_images = show_batch(dataloader, labels_dataframe, nrows, ncols)
next(sample_images)

谢谢,马克。