在 TFRecord 文件中让多个示例重用同一份数据

时间:2018-03-07 14:54:00

标签: tensorflow tensorflow-datasets tfrecord

我正在尝试通过 Dataset API 把 TFRecord 文件用作 TensorFlow 的输入管道。我的数据由输入图像和参考图像组成。问题在于多个输入图像对应同一张参考图像,因此我希望能以某种方式让多个示例重用这份参考图像数据,而不必在每个示例中重复存储。有没有办法实现这个目标?

更新#1 输入管道:

def input_fn(filenames, batch_size, isTraining=True):
    """Build a TF 1.x ``tf.data`` pipeline yielding batches of image patches.

    Every serialized example read from ``filenames`` is parsed into an input
    image and a reference image, both are cut into ``crop_size`` x
    ``crop_size`` patches, and the resulting (input patch, reference patch)
    pairs are shuffled, repeated, and batched.

    NOTE: ``crop_size`` is not a parameter and is not defined in this
    function; it must be provided by an enclosing or module-level scope.

    Args:
        filenames: TFRecord file names, passed to
            ``tf.data.Dataset.from_tensor_slices``.
        batch_size: Number of patch pairs per batch; also used as the size
            of the final prefetch buffer.
        isTraining: When true the dataset repeats indefinitely
            (``repeat(None)``); otherwise it is iterated exactly once.

    Returns:
        A ``(batch_input_image, batch_reference_image)`` tuple from a
        one-shot iterator. Because of ``_map``, the first element is a dict
        ``{'input_image': <batched input patches>}`` rather than a bare
        tensor; the second element is the batched reference patches.
    """

    def _batch_parse_example(example_proto):
        """Parse one serialized example into a dataset of patch pairs."""
        features = {
            'input_image': tf.FixedLenFeature([], tf.string),
            'input_shape': tf.FixedLenFeature([], tf.string),
            'reference_image': tf.FixedLenFeature([], tf.string),
            'reference_shape': tf.FixedLenFeature([], tf.string)
        }
        parsed_features = tf.parse_single_example(
            example_proto, features, name='features')

        # Images are stored as raw float32 bytes and shapes as raw int32
        # bytes; only the first two shape entries are used, and the channel
        # count is fixed to 3.
        input_image = tf.decode_raw(parsed_features['input_image'], tf.float32)
        input_shape = tf.decode_raw(parsed_features['input_shape'], tf.int32)
        input_image = tf.reshape(
            input_image, [input_shape[0], input_shape[1], 3])

        reference_image = tf.decode_raw(
            parsed_features['reference_image'], tf.float32)
        reference_shape = tf.decode_raw(
            parsed_features['reference_shape'], tf.int32)
        reference_image = tf.reshape(
            reference_image, [reference_shape[0], reference_shape[1], 3])

        # extract_image_patches is given a 4-D tensor here, so a leading
        # axis of size 1 is added first. The stride equals the kernel size,
        # so the patches do not overlap.
        input_image = tf.expand_dims(input_image, 0)
        input_patches = tf.extract_image_patches(
            images=input_image,
            ksizes=[1, crop_size, crop_size, 1],
            strides=[1, crop_size, crop_size, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        input_patches = tf.reshape(
            input_patches, [-1, crop_size, crop_size, 3])

        reference_image = tf.expand_dims(reference_image, 0)
        reference_patches = tf.extract_image_patches(
            images=reference_image,
            ksizes=[1, crop_size, crop_size, 1],
            strides=[1, crop_size, crop_size, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        reference_patches = tf.reshape(
            reference_patches, [-1, crop_size, crop_size, 3])

        # One dataset element per (input patch, reference patch) pair;
        # flat_map below concatenates these per-example datasets.
        return tf.data.Dataset.from_tensor_slices(
            (input_patches, reference_patches))

    def _map(input_image, reference_image):
        """Wrap the input patch in a features dict; pass the reference through."""
        return {'input_image': input_image}, reference_image

    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Read up to 8 files in an interleaved fashion, each with its own
    # 1000-record prefetch buffer.
    dataset = dataset.interleave(
        lambda x: tf.data.TFRecordDataset(x).prefetch(1000), cycle_length=8)
    dataset = dataset.flat_map(_batch_parse_example)
    dataset = dataset.map(_map)
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat(None if isTraining else 1)
    dataset = dataset.batch(batch_size)
    # Applied after batch(), so this buffers up to batch_size whole batches.
    dataset = dataset.prefetch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    batch_input_image, batch_reference_image = iterator.get_next()
    return batch_input_image, batch_reference_image

0 个答案:

没有答案