我有一个转换层 conv 的输出,其形状为Batch_size x H x W xC。 我还有另一个tenatch,其形状为Batch_size x None x2。后一个张量为bach中的每个示例提供点列表(高度和宽度坐标)(每个示例的列表长度不同)。我想为每个这些点提取Channel维。
我尝试使用tf.gather和tf.batch_gather,但是在这里似乎两者都不是正确的选择。
基本上我想要的是让每个批次 b 遍历这些点:对于每个点 i ,其h_i(高度坐标)和w_i(坐标)并返回 conv [b,h_i,w_j,:]。然后堆叠这些结果。
答案 0 :(得分:1)
这是您可以执行的操作:
import tensorflow as tf
def pick_points(images, coords):
coords = tf.convert_to_tensor(coords)
s = tf.shape(coords)
batch_size, num_coords = s[0], s[1]
# Make batch indices
r = tf.range(batch_size, dtype=coords.dtype)
idx_batch = tf.tile(tf.expand_dims(r, 1), [1, num_coords])
# Full index
idx = tf.concat([tf.expand_dims(idx_batch, 2), coords], axis=2)
# Gather pixels
pixels = tf.gather_nd(images, idx)
# Output has shape [batch_size, num_coords, num_channels]
return pixels
# Test
with tf.Graph().as_default(), tf.Session() as sess:
# 2 x 2 x 3 x 3
images = [
[
[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
],
[
[[19, 20, 21], [22, 23, 24], [25, 26, 27]],
[[28, 29, 30], [31, 32, 33], [34, 35, 36]],
],
]
# 2 x 2 x 2
coords = [
[[0, 1], [1, 2]],
[[1, 0], [1, 1]],
]
pixels = pick_points(images, coords)
print(sess.run(pixels))
# [[[ 4 5 6]
# [16 17 18]]
#
# [[28 29 30]
# [31 32 33]]]