We are running the TensorFlow code below. The problem is that memory usage keeps growing, and around epoch 30 (a little over an hour in) the process runs out of memory and stops. Any suggestions?
I tried calling gc.collect() at the end of every iteration, and I watched Spyder's variable explorer for any variable whose size might be blowing up, but all of them stay the same size for the entire run. (A sketch of an extra check I'm planning to add is at the end of this post, after the listing.)
import numpy as np
import tensorflow as tf

# load_images, ImagePool, generator, discriminator, criterionGAN, abs_criterion,
# and the hyperparameters options, L1_lambda, max_size, momentum, lr_initial
# and epoch_step are all defined elsewhere in the project.

if __name__ == '__main__':
    data_dir = "./datasets/monet2photo/"
    batch_size = 1
    epochs = 500
    mode = 'train'

    with tf.Session() as sess:
        if mode == 'train':
            imagesA, imagesB = load_images(data_dir)
            image_size = 128
            input_c_dim = 3
            output_c_dim = 3
            pool = ImagePool(max_size)
            # The whole graph is built once, here, before the training loop.
            real_data = tf.placeholder(tf.float32,
                                       [None, image_size, image_size, input_c_dim + output_c_dim],
                                       name='real_A_and_B_images')
            real_A = real_data[:, :, :, :input_c_dim]
            real_B = real_data[:, :, :, input_c_dim:input_c_dim + output_c_dim]

            # Forward passes; each generator scope is created once and then reused.
            fake_B = generator(real_A, options, False, name="generatorA2B")
            fake_A_ = generator(fake_B, options, False, name="generatorB2A")
            fake_A = generator(real_B, options, True, name="generatorB2A")
            fake_B_ = generator(fake_A, options, True, name="generatorA2B")
            DB_fake = discriminator(fake_B, options, reuse=False, name="discriminatorB")
            DA_fake = discriminator(fake_A, options, reuse=False, name="discriminatorA")
            # Generator losses: adversarial term + cycle-consistency terms
            # + an identity term (generatorA2B should leave real_B unchanged,
            # generatorB2A should leave real_A unchanged).
            g_loss_a2b = criterionGAN(DB_fake, tf.ones_like(DB_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_B, options, True, name="generatorA2B"), real_B)
            g_loss_b2a = criterionGAN(DA_fake, tf.ones_like(DA_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_A, options, True, name="generatorB2A"), real_A)
            # Combined generator loss (defined but not minimized directly below).
            g_loss = criterionGAN(DA_fake, tf.ones_like(DA_fake)) \
                + criterionGAN(DB_fake, tf.ones_like(DB_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_B, options, True, name="generatorA2B"), real_B) \
                + 0.5 * abs_criterion(generator(real_A, options, True, name="generatorB2A"), real_A)
            fake_A_sample = tf.placeholder(tf.float32,
                                           [None, image_size, image_size, input_c_dim],
                                           name='fake_A_sample')
            fake_B_sample = tf.placeholder(tf.float32,
                                           [None, image_size, image_size, output_c_dim],
                                           name='fake_B_sample')
            DB_real = discriminator(real_B, options, reuse=True, name="discriminatorB")
            DA_real = discriminator(real_A, options, reuse=True, name="discriminatorA")
            DB_fake_sample = discriminator(fake_B_sample, options, reuse=True, name="discriminatorB")
            DA_fake_sample = discriminator(fake_A_sample, options, reuse=True, name="discriminatorA")
            db_loss_real = criterionGAN(DB_real, tf.ones_like(DB_real))
            db_loss_fake = criterionGAN(DB_fake_sample, tf.zeros_like(DB_fake_sample))
            db_loss = (db_loss_real + db_loss_fake) / 2
            da_loss_real = criterionGAN(DA_real, tf.ones_like(DA_real))
            da_loss_fake = criterionGAN(DA_fake_sample, tf.zeros_like(DA_fake_sample))
            da_loss = (da_loss_real + da_loss_fake) / 2
            d_loss = da_loss + db_loss
            model_vars = tf.trainable_variables()
            d_A_vars = [var for var in model_vars if 'discriminatorA' in var.name]
            g_A_vars = [var for var in model_vars if 'generatorA2B' in var.name]
            d_B_vars = [var for var in model_vars if 'discriminatorB' in var.name]
            g_B_vars = [var for var in model_vars if 'generatorB2A' in var.name]

            # Four train ops sharing one Adam instance; each minimize() call
            # creates its own slot variables for its var_list.
            lr = tf.placeholder(tf.float32, None, name='learning_rate')
            optimizer = tf.train.AdamOptimizer(lr, beta1=momentum)
            d_A_trainer = optimizer.minimize(da_loss, var_list=d_A_vars)
            d_B_trainer = optimizer.minimize(db_loss, var_list=d_B_vars)
            g_A_trainer = optimizer.minimize(g_loss_a2b, var_list=g_A_vars)
            g_B_trainer = optimizer.minimize(g_loss_b2a, var_list=g_B_vars)

            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            # The training loop only calls sess.run and slices NumPy arrays;
            # it is not supposed to create any new graph nodes.
            num_batches = int(min(imagesA.shape[0], imagesB.shape[0]) / batch_size)
            for epoch in range(epochs):
                # Linear learning-rate decay after epoch_step.
                lr_ = lr_initial if epoch < epoch_step else lr_initial * (epochs - epoch) / (epochs - epoch_step)
                for index in range(num_batches):
                    batchA = imagesA[index * batch_size:(index + 1) * batch_size]
                    batchB = imagesB[index * batch_size:(index + 1) * batch_size]
                    batch_images = np.concatenate((batchA, batchB), axis=3).astype(np.float32)

                    # Update the two generators.
                    _, fake_B_t, gen_loss_a2b = sess.run(
                        [g_A_trainer, fake_B, g_loss_a2b],
                        feed_dict={real_data: batch_images, lr: lr_})
                    _, fake_A_t, gen_loss_b2a = sess.run(
                        [g_B_trainer, fake_A, g_loss_b2a],
                        feed_dict={real_data: batch_images, lr: lr_})

                    # Update the two discriminators on pooled fake samples.
                    [fake_A_t, fake_B_t] = pool([fake_A_t, fake_B_t])
                    _, dis_loss_a = sess.run(
                        [d_A_trainer, da_loss],
                        feed_dict={real_data: batch_images,
                                   fake_A_sample: fake_A_t,
                                   fake_B_sample: fake_B_t,
                                   lr: lr_})
                    _, dis_loss_b = sess.run(
                        [d_B_trainer, db_loss],
                        feed_dict={real_data: batch_images,
                                   fake_A_sample: fake_A_t,
                                   fake_B_sample: fake_B_t,
                                   lr: lr_})
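
The extra check I'm planning to add, as a minimal standalone sketch (it assumes TF 1.x, and psutil is a third-party package used here only to log memory): tf.Graph.finalize() makes the graph read-only, so if anything inside the training loop accidentally creates new ops, it raises a RuntimeError instead of silently growing memory.

import psutil
import tensorflow as tf

# Toy graph standing in for the real one above.
x = tf.placeholder(tf.float32, [None, 4], name='x')
y = tf.reduce_sum(x * 2.0)

with tf.Session() as sess:
    # Freeze the graph once it is fully built: any later attempt to
    # add an op (e.g. a tf.ones_like created per step) raises RuntimeError.
    sess.graph.finalize()

    proc = psutil.Process()  # handle to the current process
    for step in range(3):
        sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0, 4.0]]})
        # Log the resident set size once per step/epoch to see the growth curve.
        rss_mb = proc.memory_info().rss / 1024.0 ** 2
        print("step %d: resident set size %.1f MB" % (step, rss_mb))

If the loop runs with the graph finalized and the RSS still climbs, that would tell me the growth is on the Python/NumPy side rather than from graph construction.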