我正在尝试在 TensorFlow 输入管道中通过 Dataset API 读取 TFRecord 文件。我的数据包括输入图像和参考图像。问题在于多个输入图像共用同一张参考图像，因此我希望让多个样本复用同一份参考图像数据，而不是在每个样本中重复存储。有没有办法实现这一点？
更新 #1：输入管道如下：
def input_fn(filenames, batch_size, isTraining=True, patch_size=None,
             num_parallel_calls=8):
    """Build a TFRecord pipeline yielding batches of (input patch, reference patch).

    Each TFRecord example holds an input image and a reference image, both
    stored as raw float32 bytes, each with a raw int32 shape record.  Both
    images are cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, and the aligned patch pairs become the dataset elements.

    NOTE(review): every record carries its own copy of the reference image,
    so a reference shared by several inputs is stored and decoded once per
    input.  Avoiding that needs a different record format (e.g. write each
    reference once elsewhere and store only a reference index per input).

    Args:
        filenames: TFRecord file paths, as accepted by
            ``tf.data.Dataset.from_tensor_slices``.
        batch_size: Number of patch pairs per batch.
        isTraining: If True the dataset repeats indefinitely; otherwise it is
            iterated once.  Shuffling happens in both modes.
        patch_size: Side length of the square patches.  When omitted, the
            module-level ``crop_size`` is used (the previous behaviour).
        num_parallel_calls: Number of records decoded in parallel.

    Returns:
        ``(features, labels)`` tensors from a one-shot iterator, where
        ``features`` is ``{'input_image': <input patches>}`` and ``labels`` are
        the matching reference patches, each shaped
        ``[batch, patch_size, patch_size, 3]``.  The last batch may be smaller
        when the dataset is not repeated.
    """
    if patch_size is None:
        # Backward compatibility: earlier versions read this global directly.
        patch_size = crop_size

    feature_spec = {
        'input_image': tf.FixedLenFeature([], tf.string),
        'input_shape': tf.FixedLenFeature([], tf.string),
        'reference_image': tf.FixedLenFeature([], tf.string),
        'reference_shape': tf.FixedLenFeature([], tf.string),
    }

    def _decode_image(parsed, image_key, shape_key):
        """Decode one raw float32 image into a [height, width, 3] tensor."""
        pixels = tf.decode_raw(parsed[image_key], tf.float32)
        # Assumes the shape record starts with (height, width); channels are
        # fixed at 3 by the reshape below.
        shape = tf.decode_raw(parsed[shape_key], tf.int32)
        return tf.reshape(pixels, [shape[0], shape[1], 3])

    def _to_patches(image):
        """Split a [height, width, 3] image into [n, patch, patch, 3] tiles."""
        patches = tf.extract_image_patches(
            images=tf.expand_dims(image, 0),
            ksizes=[1, patch_size, patch_size, 1],
            # stride == kernel size -> non-overlapping tiles; with 'SAME'
            # padding, border tiles may be partly zero-padded.
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        return tf.reshape(patches, [-1, patch_size, patch_size, 3])

    def _parse_example(example_proto):
        """Decode one record into aligned (input, reference) patch stacks."""
        parsed = tf.parse_single_example(example_proto, feature_spec,
                                         name='features')
        input_patches = _to_patches(
            _decode_image(parsed, 'input_image', 'input_shape'))
        reference_patches = _to_patches(
            _decode_image(parsed, 'reference_image', 'reference_shape'))
        return input_patches, reference_patches

    def _unbatch_patches(input_patches, reference_patches):
        """Turn one record's patch stacks into one element per patch pair."""
        # Fails at runtime if the two images yield different patch counts.
        return tf.data.Dataset.from_tensor_slices(
            (input_patches, reference_patches))

    def _map(input_image, reference_image):
        """Shape elements as (features dict, label) for an Estimator."""
        return {'input_image': input_image}, reference_image

    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.interleave(
        lambda x: tf.data.TFRecordDataset(x).prefetch(1000), cycle_length=8)
    # Decode in parallel; Dataset.map keeps element order deterministic.
    # flat_map runs its function sequentially, so it only does the cheap
    # slicing step.
    dataset = dataset.map(_parse_example, num_parallel_calls=num_parallel_calls)
    dataset = dataset.flat_map(_unbatch_patches)
    dataset = dataset.map(_map)
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat(None if isTraining else 1)
    dataset = dataset.batch(batch_size)
    # NOTE(review): applied after batch(), this buffers `batch_size` whole
    # batches, not elements; kept as-is to preserve existing behaviour.
    dataset = dataset.prefetch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_input_image, batch_reference_image = iterator.get_next()
    return batch_input_image, batch_reference_image