当将tensorflow数据集API与tfrecord文件一起使用时,我想生成随机大小的裁剪。这是我的代码
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_func)
dataset = dataset.shuffle(buffer_size=10000).batch(batch_size).repeat()
iterator = dataset.make_initializable_iterator()
data_in_tsr, data_gt_tsr, = iterator.get_next()
def parse_func(example_proto):
features = {
'in': tf.FixedLenFeature(image_size, tf.float32),
'gt': tf.FixedLenFeature(image_size, tf.float32)}
parsed_features = tf.parse_single_example(example_proto, features)
mag_in = parsed_features['mag_in']
mag_gt = parsed_features['mag_gt']
# random crop 128x128
crop_sz = random.randint(mag_in.shape[0]//2, mag_in.shape[0])
mag_in = tf.random_crop(mag_in, (crop_sz, crop_sz, mag_in.shape[2]))
return mag_in, mag_gt
问题在于map()函数仅被调用一次,因此每次都获得固定大小的随机裁剪。如何生成随机大小的作物?