使用GPU训练我的GAN时出现ResourceExhaustedError

时间:2018-04-10 16:42:39

标签: tensorflow out-of-memory

当我尝试使用Tesla K80训练我的DCGAN时,遇到了ResourceExhaustedError。我已经尝试更改批量大小,但没有效果。所以我认为我的代码可能存在一些问题。遗憾的是,我没能在代码中找到问题所在。

我的代码:

import glob as gb

import numpy as np
from PIL import Image

import data_utils
from losses import adversarial_loss, generator_loss
from model import generator_model, discriminator_model, generator_containing_discriminator


def train(batch_size, epoch_num):
    """Train the generator and discriminator on the 'train' split.

    For every batch the loop trains, in order: the discriminator on the full
    images plus freshly generated images, the combined model returned by
    ``generator_containing_discriminator``, and the generator on its own.
    Every 30 batches (except batch 0) interim images and both sets of weights
    are written to disk.

    Samples left over when the dataset size is not a multiple of
    ``batch_size`` are not used.

    Args:
        batch_size: number of samples taken from the data for each step.
        epoch_num: number of passes over the training data.
    """
    # Note the x(blur) in the second, the y(full) in the first
    y_train, x_train = data_utils.load_data(data_type='train')

    # GAN
    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator(g, d)

    # compile the models, use default optimizer parameters
    # generator use adversarial loss
    g.compile(optimizer='adam', loss=generator_loss)

    # discriminator use binary cross entropy loss; make sure it is trainable
    # at the moment it is compiled
    d.trainable = True
    d.compile(optimizer='adam', loss='binary_crossentropy')

    # adversarial net use adversarial loss.
    # Keras fixes a model's set of trainable weights when compile() is called,
    # so the discriminator has to be frozen *before* the combined model is
    # compiled; flipping the flag only inside the training loop does not stop
    # the adversarial step from updating the discriminator.
    d.trainable = False
    d_on_g.compile(optimizer='adam', loss=adversarial_loss)
    d.trainable = True

    # loop-invariant values, computed once
    batch_count = x_train.shape[0] // batch_size
    # labels for the discriminator step: full images -> 1, generated -> 0
    discriminator_labels = [1] * batch_size + [0] * batch_size
    # labels for the adversarial step: every generated image is labelled 1
    adversarial_labels = [1] * batch_size

    for epoch in range(epoch_num):
        print('epoch: ', epoch + 1, '/', epoch_num)
        print('batches: ', batch_count)

        for index in range(batch_count):
            # select a batch data
            start, stop = index * batch_size, (index + 1) * batch_size
            image_blur_batch = x_train[start:stop]
            image_full_batch = y_train[start:stop]
            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

            # interim output and checkpoints happen every 30 iters, skipping 0
            is_checkpoint = (index % 30 == 0) and (index != 0)

            # output generated images for each 30 iters
            if is_checkpoint:
                data_utils.generate_image(image_full_batch, image_blur_batch, generated_images,
                                          'result/interim/', epoch, index)

            # concatenate the full and generated images,
            # the full images at top, the generated images at bottom
            x = np.concatenate((image_full_batch, generated_images))

            # train discriminator
            d_loss = d.train_on_batch(x, discriminator_labels)
            print('batch %d d_loss : %f' % (index + 1, d_loss))

            # keep the flag in step with how d_on_g was compiled (frozen)
            d.trainable = False

            # train adversarial net
            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, adversarial_labels)
            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))

            # train generator
            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)
            print('batch %d g_loss : %f' % (index + 1, g_loss))

            # keep the flag in step with how d was compiled (trainable)
            d.trainable = True

            # output weights for generator and discriminator each 30 iters
            if is_checkpoint:
                g.save_weights('weight/generator_weights.h5', True)
                d.save_weights('weight/discriminator_weights.h5', True)


def test(batch_size):
    """Run the saved generator over the 'test' split and write the results.

    Loads the generator weights from 'weight/generator_weights.h5', predicts
    on the blurred test images and hands full, blurred and predicted images
    to ``data_utils.generate_image`` with the 'result/finally/' target.

    Args:
        batch_size: batch size forwarded to the generator's ``predict`` call.
    """
    # load_data returns the full images first and the blurred images second
    full_images, blurred_images = data_utils.load_data(data_type='test')

    generator = generator_model()
    generator.load_weights('weight/generator_weights.h5')

    predictions = generator.predict(x=blurred_images, batch_size=batch_size)
    data_utils.generate_image(full_images, blurred_images, predictions, 'result/finally/')


def test_pictures(batch_size):
    """Run the saved generator on every JPEG under 'data/test/' and save PNGs.

    Each input matching 'data/test/*.jpeg' is read, normalized with
    ``data_utils.normalization``, passed through the generator loaded from
    'weight/generator_weights.h5', rescaled and written to
    'result/test/<i>.png', where ``i`` is the position of the input in the
    sorted list of matched paths.

    Args:
        batch_size: batch size forwarded to the generator's ``predict`` call.
    """
    data_path = 'data/test/*.jpeg'
    # glob makes no ordering guarantee; sorting keeps output index i tied to
    # the same input file on every run
    images_path = sorted(gb.glob(data_path))
    data_blur = []
    for image_path in images_path:
        # the context manager closes the file handle as soon as the pixel
        # data has been copied into the numpy array
        with Image.open(image_path) as image_blur:
            data_blur.append(np.array(image_blur))

    data_blur = np.array(data_blur).astype(np.float32)
    data_blur = data_utils.normalization(data_blur)

    g = generator_model()
    g.load_weights('weight/generator_weights.h5')
    generated_images = g.predict(x=data_blur, batch_size=batch_size)
    # presumably the generator output lies in [-1, 1]; this maps it to [0, 255]
    generated = generated_images * 127.5 + 127.5
    for i in range(generated.shape[0]):
        image_generated = generated[i, :, :, :]
        Image.fromarray(image_generated.astype(np.uint8)).save('result/test/' + str(i) + '.png')


if __name__ == '__main__':
    # train first so that the weight files read by the two test steps exist
    train(batch_size=2, epoch_num=5)
    # evaluate on the 'test' split
    test(batch_size=4)
    # run the generator on the loose JPEG files
    test_pictures(batch_size=2)

错误是:

2018-04-10 09:13:00 PSToutput_tensors, _, _ = self.run_internal_graph(inputs, masks)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/keras/engine/topology.py", line 2212, in run_internal_graph
2018-04-10 09:13:00 PSToutput_tensors = _to_list(layer.call(computed_tensor, **kwargs))
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/keras/layers/advanced_activations.py", line 41, in call
2018-04-10 09:13:00 PSTreturn K.relu(inputs, alpha=self.alpha)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 2668, in relu
2018-04-10 09:13:00 PSTx -= alpha * negative_part
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 821, in binary_op_wrapper
2018-04-10 09:13:00 PSTreturn func(x, y, name=name)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 1044, in _mul_dispatch
2018-04-10 09:13:00 PSTreturn gen_math_ops._mul(x, y, name=name)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 1434, in _mul
2018-04-10 09:13:00 PSTresult = _op_def_lib.apply_op("Mul", x=x, y=y, name=name)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
2018-04-10 09:13:00 PSTop_def=op_def)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
2018-04-10 09:13:00 PSToriginal_op=self._default_original_op, op_def=op_def)
2018-04-10 09:13:00 PSTFile "/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
2018-04-10 09:13:00 PSTself._traceback = _extract_stack()
2018-04-10 09:13:00 PST
2018-04-10 09:13:00 PSTResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[2,256,256,640]

有人可以解释一下ResourceExhaustedError这个错误背后的原因吗?是因为GPU的内存不足以加载数据集吗?但是当我尝试用CPU训练时,它只用了一分钟就耗尽了7.7GB内存。

0 个答案:

没有答案