在AWS Volta GPU中获取数据时内存用尽

时间:2018-07-23 05:54:52

标签: python tensorflow out-of-memory

我正在尝试复现 deep video portrait (2018),其中的模型引用了 pix2pix (2017)。我在 AWS 上租用了一块 V100 Volta 16GB GPU 来运行训练,但在导入训练数据的过程中经常发生内存不足(OOM),因此始终无法进入训练阶段(AWS 实例不断崩溃)。

以下是我编写的 parse_function,它从数据集中构造形状为 [16(batch), 256, 256, 99(11 帧 × 每帧 3 张 RGB 图)] 的张量

   def parse_function(path_set, target_set):
       """Build one (conditioning stack, target image) training pair.

       Args:
           path_set: string tensor of shape [num_frames, 3]; each row holds
               the file paths of (image, correspondence map, eye map) for one
               frame. NOTE(review): files are assumed to be JPEG — decoded
               with `tf.image.decode_jpeg` below.
           target_set: scalar string tensor, path of the target JPEG image.

       Returns:
           A tuple `(conditioning_input, target_resized)`:
               conditioning_input — [height, width, 9 * num_frames] float
                   tensor in [-1, 1] (3 RGB maps per frame, concatenated
                   along the channel axis; 11 frames -> 99 channels).
               target_resized — [height, width, 3] float tensor in [-1, 1].
       """

       def _load(path):
           # Read one JPEG, resize to (height, width), scale [0,255] -> [-1,1].
           raw = tf.read_file(path)
           decoded = tf.image.decode_jpeg(raw, channels=3)
           return tf.image.resize_images(decoded, [height, width]) / 127.5 - 1.

       conditioning_input_list = []
       for i in range(0, path_set.shape[0]):
           frame_condition_paths = path_set[i]
           # One frame contributes image + correspondence + eye map = 9 channels.
           # (Previously the load pipeline was copy-pasted three times.)
           stack_resized = tf.concat(
               [_load(frame_condition_paths[j]) for j in range(3)], 2)
           conditioning_input_list.append(stack_resized)

       # All frames stacked along the channel axis: 9 * num_frames channels.
       conditioning_input = tf.concat(conditioning_input_list, 2)

       target_resized = _load(target_set)

       return conditioning_input, target_resized


g = tf.Graph()
with g.as_default():

    # File paths are fed via placeholders at initializer time so the
    # (potentially large) path arrays are not baked into the graph proto.
    train_sets = tf.placeholder(tf.string, shape=(batch_size*train_index, 11, 3), name='train_sets')
    target_sets = tf.placeholder(tf.string, shape=(batch_size*train_index), name='target_sets')

    # Each element: ([h, w, 99] conditioning stack, [h, w, 3] target image).
    train_dataset = Dataset.from_tensor_slices((train_sets, target_sets)).map(parse_function)
    train_dataset = train_dataset.shuffle(shuffle_buffer_size).batch(batch_size).repeat(train_epoch)

    training = tf.placeholder(tf.bool)
    handle = tf.placeholder(tf.string, shape=[])

    # Feedable-handle iterator: which concrete iterator it reads from is
    # chosen per sess.run via the `handle` placeholder.
    iterator = Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
    # BUG FIX: iterator.get_next() used to be called a second time below.
    # Every get_next() call adds an independent dequeue op; running both ops
    # would advance the iterator twice per step. Build it once and reuse it.
    next_element = iterator.get_next()

    train_iterator = train_dataset.make_initializable_iterator()

    (input_images, target_image) = next_element

    # Pin the static shapes (11 frames * 9 channels = 99) for downstream ops.
    input_images.set_shape([batch_size, height, width, 99])
    target_image.set_shape([batch_size, height, width, 3])


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Initialize the dataset iterator with the actual path arrays; only the
    # string paths (not decoded images) cross the feed_dict boundary here.
    sess.run(train_iterator.initializer, feed_dict={
        train_sets : train_paths[:batch_size*train_index],
        target_sets : target_paths[:batch_size*train_index]
    })
    train_handle = sess.run(train_iterator.string_handle())

    dl_sum = 0
    gl_sum = 0
    print('feed dict')
    for i in range(train_index * train_epoch) :
        # BUG FIX: sess.run on a list returns a Python list of numpy arrays;
        # the old code called `.shape` on that list, raising AttributeError.
        # Unpack the two fetches and print each array's shape instead.
        input_batch, target_batch = sess.run(
            [input_images, target_image],
            feed_dict={handle: train_handle, training: True})
        print(input_batch.shape, target_batch.shape)

0 个答案:

没有答案