How do I read images from a directory as both input and target when training a CNN model in TensorFlow?

Posted: 2017-03-01 14:03:53

Tags: machine-learning tensorflow deconvolution

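I want to use a CNN for a deblurring task. My training data is a directory of PNG images plus a corresponding text file containing the file names.

The data is too large to load into memory in one go. Is there an API or some other method that lets me read the blurry images as the input and their ground-truth (sharp) images as the expected output for training?

I have spent a lot of time on this, but after reading the API introductions online I am still confused.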

1 answer:

Answer 0 (score: 0)

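It is less complicated than it looks. TensorFlow provides the TFRecords file format for exactly this: convert the image pairs into a TFRecords file once, then read it back through an input pipeline that streams examples from disk, so the whole dataset never has to fit in memory. Writing the file: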
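To see why a TFRecords file can be consumed without loading it whole, here is a minimal pure-Python sketch of its on-disk framing, as described in TensorFlow's TFRecord documentation: each record is an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The helper names (`crc32c`, `masked_crc`, `write_record`, `read_records`) are hypothetical and only illustrate the format; in real code you would use `tf.python_io.TFRecordWriter` and `tf.TFRecordReader`.

```python
import io
import struct

# Reflected CRC-32C (Castagnoli) polynomial, used by the TFRecord framing.
_POLY = 0x82F63B78
_TABLE = []
for _n in range(256):
    _c = _n
    for _ in range(8):
        _c = (_c >> 1) ^ _POLY if _c & 1 else _c >> 1
    _TABLE.append(_c)


def crc32c(data: bytes) -> int:
    """Table-driven CRC-32C over `data`."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a rotated-and-offset CRC rather than the raw value."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload: bytes) -> None:
    """Append one framed record (length, length CRC, payload, payload CRC)."""
    header = struct.pack("<Q", len(payload))
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))


def read_records(stream):
    """Yield payloads one at a time; only the current record is in memory."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (len_crc,) = struct.unpack("<I", stream.read(4))
        if len_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        payload = stream.read(length)
        (data_crc,) = struct.unpack("<I", stream.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload


buf = io.BytesIO()
for p in (b"blurred-0", b"sharp-0"):
    write_record(buf, p)
buf.seek(0)
print(list(read_records(buf)))
```

Because every record carries its own length prefix, a reader only ever needs the current record in memory, which is what lets the queue-based reader stream a 66742-pair dataset from disk.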

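import os

import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, queue runners)
from PIL import Image


def create_cord():
    # get_file_name(index, is_blurred), cwd, IMAGE_HEIGH and IMAGE_WIDTH are
    # defined elsewhere in the author's script (not shown here).
    writer = tf.python_io.TFRecordWriter("train.tfrecords")
    for index in range(66742):
        blur_image_path = os.path.join(cwd, get_file_name(index, True))
        orig_image_path = os.path.join(cwd, get_file_name(index, False))

        # convert('RGB') guarantees exactly 3 channels (PNGs may be RGBA or
        # palette images), matching the reshape to [..., 3] when reading.
        blur_image = Image.open(blur_image_path).convert('RGB')
        orig_image = Image.open(orig_image_path).convert('RGB')

        # PIL's resize takes (width, height).
        blur_image = blur_image.resize((IMAGE_WIDTH, IMAGE_HEIGH))
        orig_image = orig_image.resize((IMAGE_WIDTH, IMAGE_HEIGH))

        # tobytes() is row-major: height x width x 3 uint8 values.
        blur_image_raw = blur_image.tobytes()
        orig_image_raw = orig_image.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'blur_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])),
            'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw])),
        }))
        # Write inside the loop; otherwise only the last pair is saved.
        writer.write(example.SerializeToString())
    writer.close()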
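The writer and the reader must agree on the byte layout. PIL's `resize` takes `(width, height)`, while `tobytes()` on an RGB image emits rows from top to bottom, each pixel as R, G, B. The pure-Python sketch below (hypothetical helpers `pixel_from_raw` and `reshape_hwc`, no TensorFlow needed) shows the resulting index arithmetic, which is why the decoded tensor has to be reshaped to `[height, width, 3]`.

```python
def pixel_from_raw(raw: bytes, width: int, x: int, y: int):
    """Return the (R, G, B) triple at column x, row y of a row-major RGB buffer."""
    offset = (y * width + x) * 3  # 3 bytes per pixel, `width` pixels per row
    return tuple(raw[offset:offset + 3])


def reshape_hwc(raw: bytes, height: int, width: int):
    """Nested-list equivalent of reshaping the buffer to [height, width, 3]."""
    if len(raw) != height * width * 3:
        raise ValueError("buffer size does not match height * width * 3")
    return [[pixel_from_raw(raw, width, x, y) for x in range(width)]
            for y in range(height)]


# A 2-row, 3-column image whose R and G values encode the pixel's own row and column.
height, width = 2, 3
raw = bytes(v for y in range(height) for x in range(width) for v in (y, x, 0))
image = reshape_hwc(raw, height, width)
print(len(image), len(image[0]))  # number of rows, number of columns
print(image[1][2])                # pixel at row 1, column 2
```

Reshaping to `[width, height, 3]` instead would cut the rows at the wrong place for any non-square image, even though the total number of bytes matches.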

Reading the dataset:

def read_and_decode(filename):
    # Queue of input file names (TF 1.x queue-based input pipeline).
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'blur_image_raw': tf.FixedLenFeature([], tf.string),
                                           'orig_image_raw': tf.FixedLenFeature([], tf.string),
                                       })

    # Decode the raw bytes back to uint8 pixels. The shape is
    # [height, width, channels] because PIL's tobytes() is row-major.
    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8)
    blur_img = tf.reshape(blur_img, [IMAGE_HEIGH, IMAGE_WIDTH, 3])
    # Scale [0, 255] to [-0.5, 0.5].
    blur_img = tf.cast(blur_img, tf.float32) * (1. / 255) - 0.5

    # The target must be decoded from 'orig_image_raw', not 'blur_image_raw'.
    orig_img = tf.decode_raw(features['orig_image_raw'], tf.uint8)
    orig_img = tf.reshape(orig_img, [IMAGE_HEIGH, IMAGE_WIDTH, 3])
    orig_img = tf.cast(orig_img, tf.float32) * (1. / 255) - 0.5

    return blur_img, orig_img
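`read_and_decode` maps each uint8 pixel p to p/255 - 0.5, i.e. into [-0.5, 0.5]. Since the sharp target is normalized the same way, the network's prediction lives in that range too and has to be mapped back before it can be saved as a PNG. A minimal pure-Python sketch of both directions (the function names are hypothetical):

```python
def normalize(pixel: int) -> float:
    """Same mapping as read_and_decode: [0, 255] -> [-0.5, 0.5]."""
    return pixel * (1.0 / 255) - 0.5


def denormalize(value: float) -> int:
    """Inverse mapping for model outputs, clipped to the valid uint8 range."""
    pixel = round((value + 0.5) * 255)
    return max(0, min(255, pixel))


for p in (0, 128, 255):
    print(p, normalize(p), denormalize(normalize(p)))
```

The clipping matters because nothing forces a regression output to stay inside [-0.5, 0.5].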


if __name__ == '__main__':

    #  create_cord()

    blur, orig = read_and_decode("train.tfrecords")
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig],
                                                batch_size=3, capacity=1000,
                                                min_after_dequeue=100)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
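        # Start the queue runners; the Coordinator lets us stop them cleanly
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(3):
                v, l = sess.run([blur_batch, orig_batch])
                print(v.shape, l.shape)  # (3, IMAGE_HEIGH, IMAGE_WIDTH, 3) each
        finally:
            coord.request_stop()
            coord.join(threads)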
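`tf.train.shuffle_batch([blur, orig], ...)` shuffles the blurred and sharp tensors jointly, so each input stays paired with its ground truth. `capacity` bounds how many examples the queue holds, and `min_after_dequeue` is how many must remain after a batch is taken, which controls how well examples are mixed. The generator below is a simplified pure-Python model of that behaviour, not TensorFlow's implementation; `shuffle_batches` is a hypothetical name, and unlike the TF op it drains the leftover examples at the end.

```python
import random


def shuffle_batches(examples, batch_size, min_after_dequeue, seed=None):
    """Simplified shuffling queue: fill a buffer, then emit random batches while
    keeping at least `min_after_dequeue` items behind for mixing. Leftover items
    are shuffled and drained once the input is exhausted."""
    rng = random.Random(seed)
    buffer = []
    for example in examples:
        buffer.append(example)
        if len(buffer) >= min_after_dequeue + batch_size:
            yield [buffer.pop(rng.randrange(len(buffer))) for _ in range(batch_size)]
    rng.shuffle(buffer)
    for start in range(0, len(buffer), batch_size):
        yield buffer[start:start + batch_size]


# Stand-ins for (blurred, sharp) pairs; at most min_after_dequeue + batch_size
# (here 13) are held at once, regardless of how long the input stream is.
pairs = ((f"blur_{i}", f"sharp_{i}") for i in range(20))
for batch in shuffle_batches(pairs, batch_size=3, min_after_dequeue=10, seed=0):
    print(batch)
```

Each element of a batch is a whole pair, which is the property that keeps input and target aligned after shuffling.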