我有一批图像,我想在会话运行期间指定的不同位置提取补丁。大小总是一样。
如果我想在所有图像中提取相同的补丁,我当然可以只使用tf.slice(images, [px, py, 0], [size, size, 3])
。
但是我想在不同的位置切片,所以我希望px
和py
是向量。
在Numpy中,我不确定如何不使用循环来执行此操作。我会做这样的事情:
result = np.array([image[y:y+size, x:x+size] for image, x, y in zip(images, px, py)])
受此启发,我想到的TensorFlow解决方案还使用一个周期重新实现了tf.slice
,因此begin
现在变成了begin_vector
:
def my_slice(input_, begin_vector, size):
def condition(i, _):
return tf.less(i, tf.shape(input_)[0])
def body(i, r):
sliced = tf.slice(input_[i], begin_vector[i], size)
sliced = tf.expand_dims(sliced, 0)
return i+1, tf.concat((r, sliced), 0)
i = tf.constant(0)
empty_result = tf.zeros((0, *size), tf.float32)
loop = tf.while_loop(
condition, body, [i, empty_result],
[i.get_shape(), tf.TensorShape([None, *size])])
return loop[1]
然后,我可以使用我的位置矢量(这里称为ix
)来运行它:
sess = tf.Session()
images = tf.placeholder(tf.float32, (None, 256, 256, 1))
ix = tf.placeholder(tf.int32, (None, 3))
res = sess.run(
my_slice(images, ix, [10, 10, 1]),
{images: np.random.uniform(size=(2, 256, 256, 1)), ix: [[40, 80, 0], [20, 10, 0]]})
print(res.shape)
我只是想知道是否有一种更漂亮的方式来做到这一点。
PS:我知道人们问过类似的问题。例如,Slicing tensor with list - TensorFlow。但是请注意,我想使用占位符进行切片,因此我所见过的所有解决方案都不适合我。在培训过程中,一切都必须动态。我想使用占位符指定切片。 我不能使用Python的for
。我也不想打开急切执行的功能。
答案 0 :(得分:1)
这是一个无需循环的函数:
import tensorflow as tf
def extract_patches(images, px, py, w, h):
s = tf.shape(images)
ii, yy, xx = tf.meshgrid(tf.range(s[0]), tf.range(h), tf.range(w), indexing='ij')
xx2 = xx + px[:, tf.newaxis, tf.newaxis]
yy2 = yy + py[:, tf.newaxis, tf.newaxis]
# Optional: ensure indices do not go out of bounds
xx2 = tf.clip_by_value(xx2, 0, s[2] - 1)
yy2 = tf.clip_by_value(yy2, 0, s[1] - 1)
idx = tf.stack([ii, yy2, xx2], axis=-1)
return tf.gather_nd(images, idx)
这里是一个例子:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
# Works for images with any size and number of channels
images = tf.placeholder(tf.float32, (None, None, None, None))
patch_xy = tf.placeholder(tf.int32, (None, 2))
patch_size = tf.placeholder(tf.int32, (2,))
px = patch_xy[:, 0]
py = patch_xy[:, 1]
w = patch_size[0]
h = patch_size[1]
patches = extract_patches(images, px, py, w, h)
test = sess.run(patches, {
images: [
# Image 0
[[[ 0], [ 1], [ 2], [ 3]],
[[ 4], [ 5], [ 6], [ 7]],
[[ 8], [ 9], [10], [11]]],
# Image 0
[[[50], [51], [52], [53]],
[[54], [55], [56], [57]],
[[58], [59], [60], [61]]]
],
patch_xy: [[1, 0],
[0, 1]],
patch_size: [3, 2]})
print(test[..., 0])
# [[[ 1. 2. 3.]
# [ 5. 6. 7.]]
#
# [[54. 55. 56.]
# [58. 59. 60.]]]