如果我有一个创建为图块的图像数据集,将图块尺寸与批处理尺寸结合起来的最佳方法是什么?
例如,我的输入文件的形状(300,300,3)是具有300x300像素的典型RGB图像。
我进行预处理并创建一个创建新形状的图块数据集:(?,100,128,128,3) 因此,我从原始图像中创建了100个尺寸为30x30的图块,并将每个图块调整为128x128像素,然后缓存了数据集并创建了一个尺寸为?的批处理。
现在我想将图块合并到批处理维度中,并得到以下形状:(?,128,128,3)
我尝试将数据集映射到此函数:
def reshape_image(image_batch):
return tf.reshape(image_batch, (-1,128,128,3))
但这似乎不起作用,因为它导致迭代器挂起此调用:
image_test = next(iter(image_ds))
答案 0 :(得分:0)
我认为,如果您熟悉Tensorflow操作,答案就很简单,希望这个问题不会太令人困惑,并且可以帮助某个人。
#load/preprocess images from paths
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
#split images into tiles so (X,Y,C) -> (N,X,Y,C) where N is the number of tiles
image_ds = image_ds.map(split_image, num_parallel_calls=AUTOTUNE)
#resize tiled images from 30x30 to 128x128, implementation doesn't really matter
image_ds = image_ds.map(resize_image, num_parallel_calls=AUTOTUNE)
#finally the answer!! use 'flat_map', 'unstack', and 'from_tensor_slices'
#tiled_images is of shape (N,X,Y,C)
def flat_map_impl(tiled_images):
#You return a new Dataset
#Unstack by default creates a list of tensors based on the first dimension
#therefore tf.unstack(tiled_images) is a list of size N with (X,Y,C) shaped elements
#finally from_tensor_slices creates a new dataset where each element is of shape (X,Y,C)
return tf.data.Dataset.from_tensor_slices(tf.unstack(tiled_images))
#call flat_map_impl with flat_map on the dataset
image_ds = image_ds.flat_map(flat_map_impl)