我正在研究医学图像分类问题,并遇到数据集不足的问题。因此,要使用WGAN生成图像。在给定的代码中,WGAN代码示例采用了MNIST数据集。生成图像后,很容易识别出它们也属于哪个类别。但是在医学图像的情况下,生成图像后很难确定生成的图像属于哪一类,因为它们从以下给出的代码中成组保存:
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/mnist_%d.png" % epoch)
plt.close()
那么我要执行哪些更改才能获取生成图像的标签。
答案 0 :(得分:1)
原始版本的WGAN无法有条件地生成图像。因此,您经过培训的WGAN只能生成图像,而无需知道它们属于哪个类。
要能够生成特定标签的图像,请检查条件性甘子。 Here's入门的中级文章。
替代选项是从原始训练数据中训练鉴别器,并使用该鉴别器来帮助您手动对图像进行分类。