我想使用CNN解决图像去模糊任务。我的训练数据是一个存放PNG图像的目录,以及一个列出这些图像文件名的文本文件。
由于数据量太大,无法一次性全部加载到内存中。是否有API或其他方法,可以让我把模糊(blurry)图像作为输入读取,并把对应的真实清晰图像(ground truth)作为期望输出来进行训练?
我花了很多时间试图解决这个问题,但阅读了在线API文档之后,仍然感到困惑。
答案 0 :(得分:0)
方法其实并不复杂。TensorFlow提供了TFRecords文件格式,可以在训练时按需从磁盘读取样本,而不必把整个数据集一次性载入内存。首先把成对的模糊图像和原始图像写入TFRecords文件:
def _load_resized_rgb_bytes(image_path):
    """Return the raw RGB pixel bytes of the image at *image_path* after resizing."""
    # The context manager releases the file handle immediately; with tens of
    # thousands of images, relying on garbage collection can exhaust handles.
    with Image.open(image_path) as image:
        # read_and_decode reshapes every record to 3 channels, so RGBA,
        # palette or grayscale PNGs must be converted before serialization.
        image = image.convert("RGB")
        # PIL's resize takes (width, height). The argument order is kept as
        # in the original code so that the byte layout still matches the
        # [IMAGE_WIDTH, IMAGE_HEIGH, 3] reshape done in read_and_decode.
        image = image.resize((IMAGE_HEIGH, IMAGE_WIDTH))
        return image.tobytes()


def create_cord(num_examples=66742, output_path="train.tfrecords"):
    """Write paired blurred/original images into a TFRecords file.

    For each index in ``range(num_examples)`` the blurred and original file
    names are obtained from ``get_file_name(index, True)`` and
    ``get_file_name(index, False)``, resolved relative to ``cwd``, and stored
    as one ``tf.train.Example`` with the byte features ``blur_image_raw``
    and ``orig_image_raw``.

    Args:
        num_examples: Number of image pairs to serialize. Defaults to the
            size of the original data set.
        output_path: Destination TFRecords file.
    """
    # Using the writer as a context manager guarantees the file is flushed
    # and closed even if one of the images fails to load.
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for index in range(num_examples):
            blur_image_path = cwd + get_file_name(index, True)
            orig_image_path = cwd + get_file_name(index, False)

            blur_image_raw = _load_resized_rgb_bytes(blur_image_path)
            orig_image_raw = _load_resized_rgb_bytes(orig_image_path)

            example = tf.train.Example(features=tf.train.Features(feature={
                "blur_image_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[blur_image_raw])),
                "orig_image_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[orig_image_raw])),
            }))
            writer.write(example.SerializeToString())
读取数据集:
def _decode_image(raw_bytes):
    """Turn a serialized uint8 byte string into a normalized float32 image tensor."""
    image = tf.decode_raw(raw_bytes, tf.uint8)
    # Must mirror the layout produced by create_cord, which resizes with
    # PIL's (width, height) = (IMAGE_HEIGH, IMAGE_WIDTH) before tobytes().
    image = tf.reshape(image, [IMAGE_WIDTH, IMAGE_HEIGH, 3])
    # Scale from [0, 255] to [-0.5, 0.5].
    return tf.cast(image, tf.float32) * (1. / 255) - 0.5


def read_and_decode(filename):
    """Build graph ops that read one (blurred, original) image pair.

    Args:
        filename: Path to a TFRecords file whose examples carry the byte
            features ``blur_image_raw`` and ``orig_image_raw``.

    Returns:
        A tuple ``(blur_img, orig_img)`` of float32 tensors with shape
        ``[IMAGE_WIDTH, IMAGE_HEIGH, 3]`` and values in ``[-0.5, 0.5]``.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'blur_image_raw': tf.FixedLenFeature([], tf.string),
            'orig_image_raw': tf.FixedLenFeature([], tf.string),
        })
    blur_img = _decode_image(features['blur_image_raw'])
    # The target must come from 'orig_image_raw'; decoding 'blur_image_raw'
    # here would make the network learn an identity mapping on blurred input.
    orig_img = _decode_image(features['orig_image_raw'])
    return blur_img, orig_img
if __name__ == '__main__':
    # Run create_cord() once beforehand to generate "train.tfrecords".
    # create_cord()
    blur, orig = read_and_decode("train.tfrecords")
    # Shuffle single examples into mini-batches of 3 image pairs.
    blur_batch, orig_batch = tf.train.shuffle_batch(
        [blur, orig],
        batch_size=3,
        capacity=1000,
        min_after_dequeue=100)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Start the queue runners under a coordinator so the background
        # threads can be stopped and joined cleanly before the session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(3):
                v, l = sess.run([blur_batch, orig_batch])
                print(v.shape, l.shape)
        finally:
            coord.request_stop()
            coord.join(threads)