如何从不同大小的图像中提取相同大小的补丁并将它们与tensorflow数据集api一起批处理?

时间:2018-07-08 09:40:57

标签: python tensorflow machine-learning deep-learning computer-vision

我正在尝试为一组大小不同的图像制作一个tensorflow数据集api(tf版本1.8)。为此,我要从图像中提取相同大小的补丁并将其提供给我的神经网络。

问题出在tf.extract_patches_from_images中,补丁存储在通道维度中。由于每个图像的大小不同,因此每个图像的色标数量也不同。因此,每个结果图像的形状是不同的。因此,我无法使用tf数据集api将它们批处理在一起。

有人可以建议对我下面的Modify_image函数进行更改以解决此问题吗?  我猜想将补丁分成不同的图像,然后将它们批处理在一起即可。但是我不知道该怎么做。

我想扫描整个图像,因此,随机选择相等数量的补丁对我不起作用。

def modify_image(image):
'''add preprocessing functions here'''
    image = tf.expand_dims(image,0)
    image = tf.extract_image_patches(
        image,
        ksizes=[1,patch_size,patch_size,1],
        strides=[1,patch_size,patch_size,1],
        rates=[1,1,1,1],
        padding='SAME',
        name=None
    )
    image = tf.reshape(image,shape=[-1,patch_size,patch_size,1])

return image;

def parse_function(image,labels):
    image= tf.read_file(image)
    image = tf.image.decode_image(image)
    labels = tf.read_file(labels)
    labels = tf.image.decode_image(labels)
    image = modify_image(image)
    labels = modify_image(labels)
    return image,labels


def list_files(directory):
    files = glob.glob(directory)
    return files

def load_dataset(img_dir,labels_dir):
    images = list_files(img_dir)
    images = tf.constant(images)
    labels = list_files(labels_dir)
    labels = tf.constant(labels)

    dataset = tf.data.Dataset.from_tensor_slices((images,labels))
    dataset = dataset.map(parse_function)
    return dataset




def make_batches(home_dir,img_dir,labels_dir,batch_size):

    img_dir = home_dir + img_dir
    labels_dir = home_dir +labels_dir

    dataset = load_dataset(img_dir,labels_dir)
    batched_dataset = dataset.batch(batch_size)
    return batched_dataset  

1 个答案:

答案 0 :(得分:1)

tf.contrib.data.unbatch()转换在这里可能会有所帮助,因为它可以将补丁从单个图像分离为不同的元素:

dataset = tf.data.Dataset.from_tensor_slices((images,labels))
dataset = dataset.map(parse_function)
patches_dataset = dataset.apply(tf.contrib.data.unbatch())
batched_dataset = dataset.batch(batch_size)

请注意,要使tf.contrib.data.unbatch()工作,图像中的补丁数量必须与labels中的元素/行数量相匹配。例如,如果每个补丁都应获得相同的标签,则可以通过如下方式将tf.tile()的标签parse_function()修改为{{3}}来进行适当次数的修改:

def parse_function(images, labels):
  # ...
  return image, tf.tile([labels], tf.shape(image)[0:1])