我有这个生成器函数,希望通过它筛选图像数据集。图像位于批处理大小为16的PyTorch数据加载器中。我在数据加载器上循环以获取一个批处理(16张图像),然后在该批次上循环以获取图像。
我想要做的是在图像上存储图像标签,然后yield
,同时在图形上绘制16张图像。所以我想做images = next(show_batch(dataloader, labels_dataframe, nrows, ncols))
,每次得到存储在图像中的16个图像标签和16张图像的图。这样,我可以识别出不良图像,并准备好将其标签从数据集中丢弃。该代码将两次生成相同(第一张)的16张图像。我怀疑这与每次创建新列表有关,所以我要重新启动生成器吗?
为什么代码会连续两次重复生成相同的16张图像,并且如何解决将其从labels_dataframe
获取的标签存储在图像中时一次生成16张图像的问题?
def show_batch(dataloader, labels_dataframe, nrows, ncols):
fig = plt.figure(figsize=(30,15))
for i, batch in enumerate(train_dl):
images = []
for j, image in enumerate(batch['image']):
ax = fig.add_subplot(nrows, ncols, j+1)
ax.imshow(image.permute(2, 1, 0))
images.append(labels_dataframe.loc[i*16+j, 'id_code'])
yield images
答案 0 :(得分:0)
Marco Bonelli指出,我应该说明如何调用生成器。在执行此操作时,我发现自己做错了什么,并修复了如何调用它和该函数。
我正在调用next(show_batch(dataloader, labels_dataframe, nrows, ncols))
,因此每次都调用一个新的生成器函数。我没有生成生成器对象。
然后,当我创建一个生成器对象并开始调用它时,它仅显示前16张图像,然后仅生成标签,因此我将Figure对象移动到每个批处理循环中。修改后的代码及其命名方式:
def show_batch(dataloader, labels_dataframe, nrows, ncols):
for i, batch in enumerate(train_dl):
fig = plt.figure(figsize=(30,15))
images = []
for j, image in enumerate(batch['image']):
ax = fig.add_subplot(nrows, ncols, j+1)
ax.imshow(image.permute(2, 1, 0))
images.append(labels_dataframe.loc[i*16+j, 'id_code'])
yield images
sample_images = show_batch(dataloader, labels_dataframe, nrows, ncols)
next(sample_images)
谢谢,马克。