我正在尝试复现 Deep Video Portraits (2018)，其中引用的 pix2pix (2017) 模型已经实现完毕。我在 AWS 上租用了一块 V100 Volta 16GB GPU 来运行训练，但在导入训练数据的阶段经常出现内存不足（OOM），导致根本无法进入训练阶段（AWS 实例不断崩溃）。
以下是我的 parse_function 实现，它从数据集中构造形状为 [16 (batch), 256, 256, 99] 的张量（每个样本包含 11 帧，每帧 3 张 RGB 图像，共 33×3 = 99 个通道）。
def parse_function(path_set, target_set):
    """tf.data map-fn: decode one training example.

    Args:
        path_set: string tensor of shape (11, 3) — for each of the 11 frames,
            paths to [rendered frame, correspondence map, eye map] JPEGs.
        target_set: scalar string tensor — path to the ground-truth JPEG.

    Returns:
        (conditioning_input, target_resized):
            conditioning_input — float tensor (height, width, 99): the 11
                frames x 3 images x 3 RGB channels concatenated on the
                channel axis, scaled to [-1, 1].
            target_resized — float tensor (height, width, 3), same scaling.

    Note: `height` and `width` are module-level constants defined elsewhere
    in the file.
    """

    def _load_image(path):
        # read -> decode -> resize -> scale from [0, 255] to [-1, 1].
        # Shared by all four image kinds; the original repeated this
        # pipeline verbatim four times.
        raw = tf.read_file(path)
        decoded = tf.image.decode_jpeg(raw, channels=3)
        return tf.image.resize_images(decoded, [height, width]) / 127.5 - 1.

    conditioning_input_list = []
    for i in range(0, path_set.shape[0]):
        frame_condition_paths = path_set[i]
        # Per frame: rendered image, correspondence map, eye map -> 9 channels.
        stack_resized = tf.concat(
            [_load_image(frame_condition_paths[j]) for j in range(3)], 2)
        conditioning_input_list.append(stack_resized)
    # 11 frames x 9 channels -> 99-channel conditioning input.
    conditioning_input = tf.concat(conditioning_input_list, 2)

    target_resized = _load_image(target_set)
    return conditioning_input, target_resized
g = tf.Graph()
with g.as_default():
    # Only the *paths* (strings) are fed through placeholders, so the raw
    # image bytes are never baked into the graph as constants.
    train_sets = tf.placeholder(
        tf.string, shape=(batch_size * train_index, 11, 3), name='train_sets')
    target_sets = tf.placeholder(
        tf.string, shape=(batch_size * train_index), name='target_sets')

    train_dataset = Dataset.from_tensor_slices(
        (train_sets, target_sets)).map(parse_function)
    # NOTE(review): shuffle() runs *after* map(), so its buffer holds fully
    # decoded elements — each is 256*256*99 float32 ≈ 26 MB, i.e. the buffer
    # alone needs roughly shuffle_buffer_size * 26 MB of host RAM. This is
    # the most likely cause of the reported OOM: shuffle the path tensors
    # *before* .map(), or shrink shuffle_buffer_size.
    train_dataset = (train_dataset
                     .shuffle(shuffle_buffer_size)
                     .batch(batch_size)
                     .repeat(train_epoch))

    training = tf.placeholder(tf.bool)  # fed below; consumed by the model, not shown here
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    train_iterator = train_dataset.make_initializable_iterator()

    # Bug fix: the original called iterator.get_next() twice, creating a
    # second, unused dequeue op; a single get_next() is all that is needed.
    (input_images, target_image) = iterator.get_next()
    # Pin static shapes lost through the feedable-iterator indirection.
    input_images.set_shape([batch_size, height, width, 99])
    target_image.set_shape([batch_size, height, width, 3])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_iterator.initializer, feed_dict={
            train_sets: train_paths[:batch_size * train_index],
            target_sets: target_paths[:batch_size * train_index],
        })
        train_handle = sess.run(train_iterator.string_handle())
        dl_sum = 0
        gl_sum = 0
        print('feed dict')
        for i in range(train_index * train_epoch):
            fetches = [input_images, target_image]
            out = sess.run(fetches, feed_dict={handle: train_handle,
                                               training: True})
            # Bug fix: sess.run on a list returns a Python list, which has
            # no .shape — the original print(out.shape) raised
            # AttributeError. Print each fetched array's shape instead.
            print(out[0].shape, out[1].shape)