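I am trying to build a GAN, but the losses Keras computes are wrong and I don't know why. After a few steps, the discriminator and the generator both return a loss close to 0, and both report an accuracy of 1. The discriminator, the generator and the adversarial model are Keras Sequential models, and I train them with Keras's `train_on_batch` method.
Thanks in advance for your help.
Here is my code: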
```python
import os

import matplotlib.pyplot as plt
import numpy as np
from keras.preprocessing.image import array_to_img
from PIL import Image

# DCGAN (the class that builds the three Keras models) is defined elsewhere and not shown.


class Pokemon_DCGAN(object):
    def __init__(self):
        self.img_rows = 64
        self.img_cols = 64
        self.channel = 3
        self.x_train = self.directory_to_arr("pokemon-a")
        self.DCGAN = DCGAN()
        self.discriminator = self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()

    def img_to_arr(self, fileName):
        # Load as RGB and resize to 64x64 (Image.LANCZOS replaces the removed Image.ANTIALIAS alias)
        png = Image.open(fileName).convert("RGB")
        png = png.resize((64, 64), Image.LANCZOS)
        arr = np.asarray(png)
        return arr

    def arr_to_img(self, arr):
        img = Image.fromarray(np.uint8(arr))
        return img

    def directory_to_arr(self, path_of_directory):
        # Stack every image in the directory into one (N, 64, 64, 3) array
        name_img_list = os.listdir(path_of_directory)
        return np.stack([self.img_to_arr(os.path.join(path_of_directory, name))
                         for name in name_img_list])

    def random_pokemon(self):
        # Generate one image from Gaussian noise and display it
        noise = np.random.normal(0.0, 1.0, size=[1, 1, 1, 100])
        images_fakes = self.generator.predict(noise)
        print(images_fakes.shape)
        img = array_to_img(images_fakes[0])
        plt.imshow(img)
        plt.show()

    def train(self, train_steps=10000, batch_size=1, save_interval=1):  # 64 and 50
        for i in range(train_steps):
            # 1) Train the discriminator on real images (label 1) and fake images (label 0)
            self.discriminator.trainable = True
            images_train = self.x_train[np.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]
            noise = np.random.normal(0, 1, size=(batch_size, 1, 1, 100))
            images_fake = self.generator.predict(noise)
            x = np.concatenate((images_train, images_fake))
            y = np.ones([2 * batch_size, 1])
            y[batch_size:, :] = 0.0
            # The loss returned in d_loss looks incorrect
            d_loss = self.discriminator.train_on_batch(x, y)
            print(d_loss[0])
            # 2) Train the generator through the adversarial model, labelling fakes as real
            self.discriminator.trainable = False
            y = np.ones([batch_size, 1])
            # The loss returned in a_loss looks incorrect
            a_loss = self.adversarial.train_on_batch(noise, y)
            if i % save_interval == 0:
                self.random_pokemon()
                log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
                log_mesg = "%s [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
                print(log_mesg)
```
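`img_to_arr` returns raw `uint8` pixels in the range 0 to 255, and `train` concatenates them directly with the generator's output. A DCGAN generator commonly ends in `tanh` and produces values in [-1, 1]; the summary does not say which function `activation_1` uses, so that is an assumption here. A minimal sketch of two hypothetical helpers, `to_model_range` and `to_uint8`, that map between the two ranges:

```python
import numpy as np


def to_model_range(arr):
    """Map uint8 pixels in [0, 255] to float32 in [-1, 1] (assumes a tanh generator output)."""
    return arr.astype(np.float32) / 127.5 - 1.0


def to_uint8(arr):
    """Map values in [-1, 1] back to uint8 pixels for display with PIL or matplotlib."""
    return np.clip((arr + 1.0) * 127.5, 0, 255).round().astype(np.uint8)


# A tiny 1x1x3 "image" holding the lowest, a middle and the highest pixel value.
pixels = np.array([[[0, 128, 255]]], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)            # float32 values between -1 and 1
print(to_uint8(scaled))  # converts back to the original uint8 pixels
```

With helpers like these, real images would be scaled with `to_model_range` before being concatenated with `images_fake`, and generated images converted with `to_uint8` before display, so both inputs to the discriminator share one scale.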
Architecture:

Generator:

```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_transpose_1 (Conv2DTr (None, 4, 4, 512)         819712
_________________________________________________________________
batch_normalization_1 (Batch (None, 4, 4, 512)         2048
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 4, 4, 512)         0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 8, 8, 256)         2097408
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 256)         1024
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 256)         0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 16, 16, 128)       524416
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 128)       512
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 16, 16, 128)       0
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 32, 32, 64)        131136
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 64)        256
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 32, 32, 64)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928
_________________________________________________________________
batch_normalization_5 (Batch (None, 32, 32, 64)        256
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 32, 32, 64)        0
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 64, 64, 3)         3075
_________________________________________________________________
activation_1 (Activation)    (None, 64, 64, 3)         0
=================================================================
Total params: 3,616,771
Trainable params: 3,614,723
Non-trainable params: 2,048
```
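The parameter counts in the generator summary can be reproduced by hand, which confirms the layer sizes. The kernel sizes are not given in the post; 4x4 kernels for every `Conv2DTranspose`, a 3x3 kernel for `conv2d_1` and a 100-dimensional noise input are inferred because they reproduce the printed numbers exactly. `conv_params` and `bn_params` are illustrative helpers:

```python
# Assumptions (inferred from the summary, not stated in the post): 4x4 kernels for the
# Conv2DTranspose layers, a 3x3 kernel for conv2d_1, 100 noise channels, biases enabled.


def conv_params(k, c_in, c_out):
    """Conv2D / Conv2DTranspose: a k x k x c_in kernel per output channel, plus one bias each."""
    return k * k * c_in * c_out + c_out


def bn_params(c):
    """BatchNormalization: gamma and beta (trainable) + moving mean and variance (non-trainable)."""
    return 4 * c


layers = [
    conv_params(4, 100, 512), bn_params(512),
    conv_params(4, 512, 256), bn_params(256),
    conv_params(4, 256, 128), bn_params(128),
    conv_params(4, 128, 64), bn_params(64),
    conv_params(3, 64, 64), bn_params(64),
    conv_params(4, 64, 3),
]
total = sum(layers)
# Only the moving mean and moving variance of each BatchNormalization layer are non-trainable.
non_trainable = sum(2 * c for c in (512, 256, 128, 64, 64))
print(total, total - non_trainable, non_trainable)
```

The three printed values match the "Total", "Trainable" and "Non-trainable" lines of the summary.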
Discriminator:

```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_2 (Conv2D)            (None, 32, 32, 64)        3136
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 32, 32, 64)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       131200
_________________________________________________________________
batch_normalization_6 (Batch (None, 16, 16, 128)       512
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 128)       0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 256)         524544
_________________________________________________________________
batch_normalization_7 (Batch (None, 8, 8, 256)         1024
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 8, 8, 256)         0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 4, 4, 512)         2097664
_________________________________________________________________
batch_normalization_8 (Batch (None, 4, 4, 512)         2048
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 4, 4, 512)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 8192)              0
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 8193
_________________________________________________________________
activation_2 (Activation)    (None, 1)                 0
=================================================================
Total params: 2,768,321
Trainable params: 2,766,529
Non-trainable params: 1,792
```
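The discriminator's shapes follow from strided convolutions that halve the spatial size at each step (64, 32, 16, 8, 4), so `Flatten` produces 4 x 4 x 512 = 8192 features. The sketch below rebuilds those shapes and the parameter total; 4x4 kernels, stride 2 and `padding='same'` are assumptions inferred from the summary, and the helper names are illustrative:

```python
import math

# Assumptions (inferred from the summary): 64x64x3 input, 4x4 kernels, stride 2, padding='same'.


def conv_params(k, c_in, c_out):
    """Conv2D: a k x k x c_in kernel per output channel, plus one bias each."""
    return k * k * c_in * c_out + c_out


def same_conv_out(size, stride):
    """Spatial output size of a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)


size = 64
channels = [3, 64, 128, 256, 512]
sizes = []
conv_total = 0
for c_in, c_out in zip(channels, channels[1:]):
    size = same_conv_out(size, 2)
    sizes.append(size)
    conv_total += conv_params(4, c_in, c_out)

flat = size * size * channels[-1]  # Flatten: 4 * 4 * 512
dense = flat + 1                   # Dense(1): one weight per input feature + one bias
bn = 4 * (128 + 256 + 512)         # three BatchNormalization layers
total = conv_total + dense + bn
non_trainable = 2 * (128 + 256 + 512)  # moving mean + moving variance
print(sizes, flat, total, non_trainable)
```

With `padding='valid'` the first layer would give 31 rather than 32, which is why `'same'` is the assumption that fits the printed shapes.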
Losses for the first steps:

```
0: [D loss: 0.963444, acc: 0.609375] [A loss: 0.296040, acc: 0.890625]
5: [D loss: 0.002083, acc: 1.000000] [A loss: 0.000003, acc: 1.000000]
10: [D loss: 0.004108, acc: 1.000000] [A loss: 0.000004, acc: 1.000000]
15: [D loss: 0.000665, acc: 1.000000] [A loss: 0.000003, acc: 1.000000]
20: [D loss: 0.000458, acc: 1.000000] [A loss: 0.000002, acc: 1.000000]
```
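The compile calls are not shown, so this assumes both models use `binary_crossentropy` with an accuracy metric that thresholds predictions at 0.5. Under that assumption, the small sketch below computes binary cross-entropy by hand and shows what a loss like 0.000003 next to an accuracy of 1.0 means numerically. The functions are illustrative re-implementations of the formulas, not Keras code:

```python
import math


def binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy: -[t*log(p) + (1 - t)*log(1 - p)], averaged over the batch."""
    return -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for t, p in zip(y_true, y_pred)) / len(y_true)


def accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions on the correct side of the threshold."""
    return sum((p > threshold) == bool(t) for t, p in zip(y_true, y_pred)) / len(y_true)


# A model that is extremely confident and correct on every sample of the batch.
y_true = [1, 1, 0, 0]
y_pred = [1 - 3e-6, 1 - 3e-6, 3e-6, 3e-6]
loss = binary_crossentropy(y_true, y_pred)
print(loss, accuracy(y_true, y_pred))
```

A loss of about 3e-6 corresponds to every prediction sitting within a few millionths of its label, so the loss and accuracy values in the log are consistent with each other for each individual `train_on_batch` call.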