Tensorflow:基于来自另一个张量的信息,从一个张量中选择不重叠切片的列表

时间:2019-02-27 09:35:28

标签: python tensorflow slice

我有一个转换层 conv 的输出,其形状为Batch_size x H x W xC。 我还有另一个tenatch,其形状为Batch_size x None x2。后一个张量为bach中的每个示例提供点列表(高度和宽度坐标)(每个示例的列表长度不同)。我想为每个这些点提取Channel维。

我尝试使用tf.gather和tf.batch_gather,但是在这里似乎两者都不是正确的选择。

基本上我想要的是让每个批次 b 遍历这些点:对于每个点 i ,其h_i(高度坐标)和w_i(坐标)并返回 conv [b,h_i,w_j,:]。然后堆叠这些结果。

1 个答案:

答案 0 :(得分:1)

这是您可以执行的操作:

import tensorflow as tf

def pick_points(images, coords):
    coords = tf.convert_to_tensor(coords)
    s = tf.shape(coords)
    batch_size, num_coords = s[0], s[1]
    # Make batch indices
    r = tf.range(batch_size, dtype=coords.dtype)
    idx_batch = tf.tile(tf.expand_dims(r, 1), [1, num_coords])
    # Full index
    idx = tf.concat([tf.expand_dims(idx_batch, 2), coords], axis=2)
    # Gather pixels
    pixels = tf.gather_nd(images, idx)
    # Output has shape [batch_size, num_coords, num_channels]
    return pixels

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    # 2 x 2 x 3 x 3
    images = [
        [
            [[ 1,  2,  3], [ 4,  5,  6], [ 7,  8,  9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        ],
        [
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
            [[28, 29, 30], [31, 32, 33], [34, 35, 36]],
        ],
    ]
    # 2 x 2 x 2
    coords = [
        [[0, 1], [1, 2]],
        [[1, 0], [1, 1]],
    ]
    pixels = pick_points(images, coords)
    print(sess.run(pixels))
    # [[[ 4  5  6]
    #   [16 17 18]]
    #
    #  [[28 29 30]
    #   [31 32 33]]]