TensorFlow application's memory usage keeps growing

Date: 2019-07-12 12:16:07

Tags: python tensorflow memory memory-leaks

We are running the TensorFlow code below, and the problem is that memory usage keeps growing; around epoch 30 (a little over an hour in), the process runs out of memory and stops. Any suggestions?

I tried calling gc.collect() at the end of every iteration, and I watched Spyder's variable explorer for any variable growing without bound, but their sizes stayed constant for the entire run.
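A common cause of this symptom in graph-mode TensorFlow is ops being appended to the graph inside the training loop; gc.collect() cannot reclaim those. One way to rule that in or out is to freeze the graph before the loop with Graph.finalize(), after which any op creation raises immediately. A minimal, self-contained sketch of the check (the toy placeholder and loop here are stand-ins, not the code below):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3], name='x')
y = tf.reduce_sum(x)

with tf.Session() as sess:
    sess.graph.finalize()  # adding any op after this point raises RuntimeError
    for step in range(3):
        # tf.square(y)  # uncommenting this would now fail, exposing a graph leak
        n_ops = len(sess.graph.get_operations())
        print('step', step, 'ops in graph:', n_ops)  # should stay constant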

import gc
import numpy as np
import tensorflow as tf

# NOTE: load_images, generator, discriminator, criterionGAN, abs_criterion and
# ImagePool, as well as the hyperparameters options, max_size, L1_lambda,
# momentum, lr_initial and epoch_step, are defined elsewhere and omitted here.

if __name__ == '__main__':
    data_dir = "./datasets/monet2photo/"
    batch_size = 1
    epochs = 500
    mode = 'train'

    with tf.Session() as sess:

        if mode == 'train':
            imagesA, imagesB = load_images(data_dir)

            image_size = 128
            input_c_dim = 3
            output_c_dim = 3

            pool = ImagePool(max_size)

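            # One placeholder carries a domain-A image and a domain-B image
            # concatenated along the channel axis; it is split back out below.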
            real_data = tf.placeholder(tf.float32,
                                       [None, image_size, image_size, input_c_dim + output_c_dim],
                                       name='real_A_and_B_images')

            real_A = real_data[:, :, :, :input_c_dim]
            real_B = real_data[:, :, :, input_c_dim:input_c_dim + output_c_dim]

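            # CycleGAN mappings: A->B and back to A, then B->A and back to B.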
            fake_B = generator(real_A, options, False, name="generatorA2B")
            fake_A_ = generator(fake_B, options, False, name="generatorB2A")
            fake_A = generator(real_B, options, True, name="generatorB2A")
            fake_B_ = generator(fake_A, options, True, name="generatorA2B")

            DB_fake = discriminator(fake_B, options, reuse=False, name="discriminatorB")
            DA_fake = discriminator(fake_A, options, reuse=False, name="discriminatorA")

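            # Generator losses: an adversarial term, L1 cycle-consistency terms,
            # and identity terms (a generator fed its own target domain should
            # return the input unchanged).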
            g_loss_a2b = criterionGAN(DB_fake, tf.ones_like(DB_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_B, options, True, name="generatorA2B"), real_B)
            g_loss_b2a = criterionGAN(DA_fake, tf.ones_like(DA_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_A, options, True, name="generatorB2A"), real_A)
            g_loss = criterionGAN(DA_fake, tf.ones_like(DA_fake)) \
                + criterionGAN(DB_fake, tf.ones_like(DB_fake)) \
                + L1_lambda * abs_criterion(real_A, fake_A_) \
                + L1_lambda * abs_criterion(real_B, fake_B_) \
                + 0.5 * abs_criterion(generator(real_B, options, True, name="generatorA2B"), real_B) \
                + 0.5 * abs_criterion(generator(real_A, options, True, name="generatorB2A"), real_A)
            # the last two terms are the identity mappings
            # (B->B via generatorA2B, then A->A via generatorB2A)

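            # Placeholders for previously generated fakes drawn from the image
            # pool; the discriminators train on these rather than only on the
            # freshest generator outputs.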
            fake_A_sample = tf.placeholder(tf.float32,
                                           [None, image_size, image_size,
                                           input_c_dim], name='fake_A_sample')
            fake_B_sample = tf.placeholder(tf.float32,
                                           [None, image_size, image_size,
                                           output_c_dim], name='fake_B_sample')

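            # Discriminator outputs on real images and on pooled fakes.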
            DB_real = discriminator(real_B, options, reuse=True, name="discriminatorB")
            DA_real = discriminator(real_A, options, reuse=True, name="discriminatorA")
            DB_fake_sample = discriminator(fake_B_sample, options, reuse=True, name="discriminatorB")
            DA_fake_sample = discriminator(fake_A_sample, options, reuse=True, name="discriminatorA")

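            # Discriminator losses: reals labelled 1, pooled fakes labelled 0.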
            db_loss_real = criterionGAN(DB_real, tf.ones_like(DB_real))
            db_loss_fake = criterionGAN(DB_fake_sample, tf.zeros_like(DB_fake_sample))
            db_loss = (db_loss_real + db_loss_fake) / 2
            da_loss_real = criterionGAN(DA_real, tf.ones_like(DA_real))
            da_loss_fake = criterionGAN(DA_fake_sample, tf.zeros_like(DA_fake_sample))
            da_loss = (da_loss_real + da_loss_fake) / 2
            d_loss = da_loss + db_loss

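            # Partition the trainable variables by sub-network name so each
            # optimizer step only updates its own network.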
            model_vars = tf.trainable_variables()
            d_A_vars = [var for var in model_vars if 'discriminatorA' in var.name]
            g_A_vars = [var for var in model_vars if 'generatorA2B' in var.name]
            d_B_vars = [var for var in model_vars if 'discriminatorB' in var.name]
            g_B_vars = [var for var in model_vars if 'generatorB2A' in var.name]

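            # A single shared Adam instance; each minimize() call creates its
            # own slot variables for the variables it updates.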
            lr = tf.placeholder(tf.float32, None, name='learning_rate')
            optimizer = tf.train.AdamOptimizer(lr, beta1=momentum)

            d_A_trainer = optimizer.minimize(da_loss, var_list=d_A_vars)
            d_B_trainer = optimizer.minimize(db_loss, var_list=d_B_vars)
            g_A_trainer = optimizer.minimize(g_loss_a2b, var_list=g_A_vars)
            g_B_trainer = optimizer.minimize(g_loss_b2a, var_list=g_B_vars)

            init_op = tf.global_variables_initializer()
            sess.run(init_op)

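            # Training loop. The learning rate stays at lr_initial for the
            # first epoch_step epochs, then decays linearly to zero.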
            for epoch in range(epochs):
                num_batches = int(min(imagesA.shape[0], imagesB.shape[0]) / batch_size)
                lr_ = lr_initial if epoch < epoch_step else lr_initial*(epochs-epoch)/(epochs-epoch_step)

                for index in range(num_batches):
                    batchA = imagesA[index * batch_size:(index + 1) * batch_size]
                    batchB = imagesB[index * batch_size:(index + 1) * batch_size]
                    batch_images = np.concatenate((batchA, batchB), axis=3)
                    batch_images = np.array(batch_images).astype(np.float32)

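                    # Generator updates; the fakes they return are mixed into
                    # the image pool for the discriminator steps below.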
                    _, fake_B_t, gen_loss_a2b = sess.run(
                        [g_A_trainer, fake_B, g_loss_a2b],
                        feed_dict={real_data: batch_images, lr: lr_})
                    _, fake_A_t, gen_loss_b2a = sess.run(
                        [g_B_trainer, fake_A, g_loss_b2a],
                        feed_dict={real_data: batch_images, lr: lr_})
                    [fake_A_t, fake_B_t] = pool([fake_A_t, fake_B_t])

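                    # Discriminator updates on real images and pooled fakes.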
                    _, dis_loss_a = sess.run(
                        [d_A_trainer, da_loss],
                        feed_dict={real_data: batch_images,
                                   fake_A_sample: fake_A_t,
                                   fake_B_sample: fake_B_t,
                                   lr: lr_})    
                    _, dis_loss_b = sess.run(
                        [d_B_trainer, db_loss],
                        feed_dict={real_data: batch_images,
                                   fake_A_sample: fake_A_t,
                                   fake_B_sample: fake_B_t,
                                   lr: lr_})

0 Answers:

No answers yet.