TensorFlow Dataset: adding tiled images to the batch dimension

Time: 2019-10-18 20:16:46

Tags: python tensorflow tensorflow-datasets

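If I have an image dataset whose elements are created as tiles, what is the best way to merge the tile dimension into the batch dimension?

For example, my input files have shape (300, 300, 3), i.e. typical RGB images of 300x300 pixels.

I preprocess them and create a tiled dataset with a new shape: (?, 100, 128, 128, 3). That is, I cut 100 tiles of 30x30 pixels out of each original image and resize each tile to 128x128 pixels, then cache the dataset and batch it with a batch size of ?.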
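For concreteness, the tiling step described above could be sketched as follows. This is only an illustration under stated assumptions: the helper names `split_image` and `resize_image` are hypothetical bodies for functions whose implementation is not shown here, the image shape is assumed to be statically known, and a random tensor stands in for a real image.

```python
import tensorflow as tf

TILE = 30     # tile edge in pixels (from the question)
GRID = 10     # 300 / 30 tiles per side, so GRID * GRID = 100 tiles
TARGET = 128  # edge of each resized tile (from the question)


def split_image(image):
    """Hypothetical helper: (300, 300, 3) -> (100, 30, 30, 3)."""
    channels = image.shape[-1]
    # Split rows and columns into (grid index, offset inside the tile).
    tiles = tf.reshape(image, (GRID, TILE, GRID, TILE, channels))
    # Move both grid indices to the front: (GRID, GRID, TILE, TILE, C).
    tiles = tf.transpose(tiles, (0, 2, 1, 3, 4))
    # Merge the two grid indices into a single tile index.
    return tf.reshape(tiles, (GRID * GRID, TILE, TILE, channels))


def resize_image(tiles):
    """Hypothetical helper: (100, 30, 30, 3) -> (100, 128, 128, 3)."""
    # tf.image.resize treats the leading dimension of a 4-D tensor as a batch.
    return tf.image.resize(tiles, (TARGET, TARGET))


image = tf.random.uniform((300, 300, 3))  # stand-in for a loaded RGB image
tiles = split_image(image)
resized = resize_image(tiles)
```

Tiles are numbered row by row, so `tiles[1]` is the second tile in the top row and `tiles[10]` is the first tile in the second row.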

Now I want to merge the tiles into the batch dimension and end up with the shape (?, 128, 128, 3).

I tried mapping the dataset with this function:

import tensorflow as tf

def reshape_image(image_batch):
    # Collapse the batch and tile dimensions into one leading dimension.
    return tf.reshape(image_batch, (-1, 128, 128, 3))

But this does not seem to work, because it causes the iterator to hang on this call:

image_test = next(iter(image_ds))

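1 Answer:

Answer 0: (score: 0)

I think the answer is simple once you are familiar with TensorFlow operations. Hopefully the question was not too confusing and this helps someone.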

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Load and preprocess images from their paths.
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
# Split each image into tiles: (X, Y, C) -> (N, X, Y, C), where N is the number of tiles.
image_ds = image_ds.map(split_image, num_parallel_calls=AUTOTUNE)
# Resize the tiles from 30x30 to 128x128; the implementation doesn't really matter.
image_ds = image_ds.map(resize_image, num_parallel_calls=AUTOTUNE)

# Finally, the answer: use flat_map, unstack, and from_tensor_slices.
# tiled_images has shape (N, X, Y, C).
def flat_map_impl(tiled_images):
    # flat_map expects the mapped function to return a new Dataset.
    # tf.unstack splits along the first dimension by default, so
    # tf.unstack(tiled_images) is a list of N tensors of shape (X, Y, C).
    # from_tensor_slices then builds a dataset whose elements have shape (X, Y, C).
    return tf.data.Dataset.from_tensor_slices(tf.unstack(tiled_images))

# Apply flat_map_impl to every element with flat_map.
image_ds = image_ds.flat_map(flat_map_impl)
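To check the resulting shapes, here is a small self-contained sketch of the same idea on synthetic data. It assumes TensorFlow 2.x; the constants, the random input, and the scaled-down tile size (16 instead of 128, to keep memory use small) are illustrative choices, not part of the original pipeline. It also shows `Dataset.unbatch()`, which should produce the same flattened elements.

```python
import tensorflow as tf

NUM_IMAGES, NUM_TILES, SIDE = 4, 100, 16  # SIDE stands in for 128

# Synthetic stand-in for the tiled dataset: one element per image,
# each element holding all tiles of that image, shape (100, 16, 16, 3).
tiled = tf.random.uniform((NUM_IMAGES, NUM_TILES, SIDE, SIDE, 3))
tiled_ds = tf.data.Dataset.from_tensor_slices(tiled)

# Option 1: flat_map. from_tensor_slices slices each element along its
# first dimension, so every tile becomes its own dataset element.
flat_ds = tiled_ds.flat_map(
    lambda tiles: tf.data.Dataset.from_tensor_slices(tiles)
)

# Option 2: unbatch() splits each element along its first dimension too.
unbatched_ds = tiled_ds.unbatch()

# Batch only after flattening, so tiles share the batch dimension.
batched_ds = flat_ds.batch(32)
first_batch = next(iter(batched_ds))  # shape (32, 16, 16, 3)

flat_count = sum(1 for _ in flat_ds)  # NUM_IMAGES * NUM_TILES elements
```

In this sketch the tensor is passed to `from_tensor_slices` directly, without `tf.unstack`; `tf.unstack` needs the size of the first dimension to be known statically, whereas slicing the tensor does not. Batching after the flatten step is what yields elements of shape (?, 128, 128, 3) in the original setting.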