我正在尝试使用TensorFlow提取同一图像中一组2D地标周围的多个补丁。
给出形状为[batch_size, num_landmarks, 2]
的2D界标的输入张量和形状为[batch_size, num_rows, num_cols, num_channels]
的输入图像张量,我想返回一个包含[batch_size, num_landmarks, patch_rows, patch_cols, num_channels]
的张量。
现在我们可以假设batch_size=1
,如果是这样,下面的代码将执行上述操作:
im = tf.tile(im, (num_landmarks, 1, 1, 1))
patches = tf.image.extract_glimpse(im, (patch_cols, patch_rows), landmarks, centered=False, normalized=False)
基本上,我会重复输入图像达到具有界标的次数,然后提取瞥见。当我有很多地标时,这当然是疯了,所以我想知道是否存在更好的方法。
编辑:
我认为tf.gather_nd
可以解决问题,因此我正在构建我需要提取补丁的索引。