将tf.extract_image_patches
与填充'SAME'
一起使用将导致某些补丁包含填充(很好)。
是否有一种简单的方法来获取TensorFlow布尔蒙版,该蒙版会屏蔽所有包含填充的补丁?还是我需要重新实现填充过程?
答案 0 :(得分:0)
我当前的解决方案是添加一个表示位标志的附加通道。
提取图像补丁后,对于填充通道,位标记为CMD ["flask", "run", "--host=0.0.0.0"]
,对于非填充通道位标记为0
。
完整解决方案:
1
上面的代码中的input_tensor = tf.random.normal([10, 28, 28, 1])
window_shape, strides, padding = (4, 4), (2, 2), 'SAME'
# ----------------------------
bits = tf.ones([tf.shape(input_tensor)[0], input_tensor.shape[1], input_tensor.shape[2], 1])
input_for_patching = tf.concat([input_tensor, bits], axis=-1)
patches = tf.extract_image_patches(input_for_patching, ksizes=(1, *window_shape, 1), strides=(1, *strides, 1), rates=(1, 1, 1, 1), padding=padding)
patches_shape = patches.shape
patches = tf.reshape(patches, [-1, *window_shape, input_tensor.shape[3] + 1])
padding_mask = tf.to_float(tf.reduce_all(tf.equal(patches[:, :, :, -1:], 1.0), [1, 2, 3]))
patches = tf.reshape(patches[:, :, :, :-1], [-1, patches_shape[1], patches_shape[2], window_shape[0] * window_shape[1] * input_tensor.shape[3]])
是我所需要的。
如果某人拥有更短,更优雅和/或更完整的版本,请随时分享。