使用tf.extract_image_patches生成图像补丁以有效地生成一对图像

时间:2019-01-09 07:29:06

标签: tensorflow

我正在尝试为一对图像生成图像补丁。我已经成功生成了匹配的补丁。但是,代码效率不高,因为我对图像及其对重复了相同的过程。我相信通过堆叠或串联它们可能会使该过程瘫痪,但是我失败了。您能提高效率吗?

channel_nb = 3
img=tf.ones([600,400,3])
target_img=tf.ones([600,400,3])
h, w = 64, 64
ksizes = [1, h, w, 1]
strides = [1, h//2, w//2, 1]
rates=[1, 1, 1, 1]
padding = 'VALID'

# Following part is not efficient, but correct
img_patches = tf.image.extract_image_patches(tf.expand_dims(img, axis=0), ksizes, strides, rates, padding)
img_patches = tf.reshape(img_patches, [tf.reduce_prod(tf.shape(img_patches)[0:3]), h, w, channel_nb])

target_img_patches = tf.image.extract_image_patches(tf.expand_dims(target_img, axis=0), ksizes, strides, rates,                                           padding)
target_img_patches = tf.reshape(target_img_patches, [tf.reduce_prod(tf.shape(target_img_patches)[0:3]), h, w, channel_nb])

1 个答案:

答案 0 :(得分:0)

patches =  tf.image.extract_image_patches(tf.stack([img,target_img]), ksizes, strides, rates, padding)
patches_shape = tf.shape(patches)
patches = tf.reshape(patches, [2, patches_shape[1]*patches_shape[2], h, w, 3])
img_patches = patches[0, :, :, :, :]
targe_img_patches = patches[1, :, :, :, :]