TensorFlow:使用占位符将张量切片

时间:2019-05-24 11:38:34

标签: python tensorflow

我有一批图像,我想在会话运行期间指定的不同位置提取补丁。大小总是一样。

如果我想在所有图像中提取相同的补丁,我当然可以只使用tf.slice(images, [px, py, 0], [size, size, 3])

slice same position

但是我想在不同的位置切片,所以我希望pxpy是向量。

slice different positions

在Numpy中,我不确定如何不使用循环来执行此操作。我会做这样的事情:

result = np.array([image[y:y+size, x:x+size] for image, x, y in zip(images, px, py)])

受此启发,我想到的TensorFlow解决方案还使用一个周期重新实现了tf.slice,因此begin现在变成了begin_vector

def my_slice(input_, begin_vector, size):
    def condition(i, _):
        return tf.less(i, tf.shape(input_)[0])
    def body(i, r):
        sliced = tf.slice(input_[i], begin_vector[i], size)
        sliced = tf.expand_dims(sliced, 0)
        return i+1, tf.concat((r, sliced), 0)

    i = tf.constant(0)
    empty_result = tf.zeros((0, *size), tf.float32)
    loop = tf.while_loop(
        condition, body, [i, empty_result],
        [i.get_shape(), tf.TensorShape([None, *size])])
    return loop[1]

然后,我可以使用我的位置矢量(这里称为ix)来运行它:

sess = tf.Session()
images = tf.placeholder(tf.float32, (None, 256, 256, 1))
ix = tf.placeholder(tf.int32, (None, 3))
res = sess.run(
  my_slice(images, ix, [10, 10, 1]),
  {images: np.random.uniform(size=(2, 256, 256, 1)), ix: [[40, 80, 0], [20, 10, 0]]})
print(res.shape)

我只是想知道是否有一种更漂亮的方式来做到这一点。

PS:我知道人们问过类似的问题。例如,Slicing tensor with list - TensorFlow。但是请注意,我想使用占位符进行切片,因此我所见过的所有解决方案都不适合我。在培训过程中,一切都必须动态。我想使用占位符指定切片。 我不能使用Python的for我也不想打开急切执行的功能。

1 个答案:

答案 0 :(得分:1)

这是一个无需循环的函数:

import tensorflow as tf

def extract_patches(images, px, py, w, h):
    s = tf.shape(images)
    ii, yy, xx = tf.meshgrid(tf.range(s[0]), tf.range(h), tf.range(w), indexing='ij')
    xx2 = xx + px[:, tf.newaxis, tf.newaxis]
    yy2 = yy + py[:, tf.newaxis, tf.newaxis]
    # Optional: ensure indices do not go out of bounds
    xx2 = tf.clip_by_value(xx2, 0, s[2] - 1)
    yy2 = tf.clip_by_value(yy2, 0, s[1] - 1)
    idx = tf.stack([ii, yy2, xx2], axis=-1)
    return tf.gather_nd(images, idx)

这里是一个例子:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # Works for images with any size and number of channels
    images = tf.placeholder(tf.float32, (None, None, None, None))
    patch_xy = tf.placeholder(tf.int32, (None, 2))
    patch_size = tf.placeholder(tf.int32, (2,))
    px = patch_xy[:, 0]
    py = patch_xy[:, 1]
    w = patch_size[0]
    h = patch_size[1]
    patches = extract_patches(images, px, py, w, h)
    test = sess.run(patches, {
        images: [
            # Image 0
            [[[ 0], [ 1], [ 2], [ 3]],
             [[ 4], [ 5], [ 6], [ 7]],
             [[ 8], [ 9], [10], [11]]],
            # Image 0
            [[[50], [51], [52], [53]],
             [[54], [55], [56], [57]],
             [[58], [59], [60], [61]]]
        ],
        patch_xy: [[1, 0],
                   [0, 1]],
        patch_size: [3, 2]})
    print(test[..., 0])
    # [[[ 1.  2.  3.]
    #   [ 5.  6.  7.]]
    #
    #  [[54. 55. 56.]
    #   [58. 59. 60.]]]