我有一个尺寸为[B,Y,X,N]的4D张量参数,并希望从中选择一个特定的切片n ∈ N
,这样我得到的张量就是尺寸[B,Y,X,1](或[B,Y,X])。
特定切片应该是平均包含最高数字的切片;我得到如此的指数:
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
(形状[B])
我使用gather
或gather_nd
尝试了不同的解决方案,但无法使其发挥作用。有很多帖子与此非常相似,但我无法应用其中的一种解决方案。
我正在运行Tensorflow 1.3,因此可以使用gather
的新奇轴参数。
答案 0 :(得分:0)
在下面的示例代码中,输入的形状为[2,3,4,5]
,结果形状为[2,3,4]
。
主要观点是:
gather_nd
很容易获得一行而不是一列,所以我用tf.transpose
切换了最后两个维度。tf.argmax
(下面indices
)的索引转换为final_idx
中真正可用的索引(请参阅下面的tf.gather_nd
)。转换是通过堆叠三个组件完成的:
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
所以我们可以从[3, 0]
转到
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]].
Batch,Y,X = 2, 3, 4
tf.reset_default_graph()
data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)
Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]),
[1,Y]), [-1]), tf.int32)
idx = tf.reshape(tf.range(batch_size), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, y]), [-1])
inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
transposed = tf.transpose(Params, [0, 1, 3, 2])
slice = tf.gather_nd(transposed, final_idx)
with tf.Session() as sess:
print sess.run(Params)
print sess.run(idx)
print sess.run(inc)
print sess.run(indices)
print sess.run(final_idx)
print sess.run(slice)
[[[[ 22 38 68 49 119]
[ 47 74 111 117 90]
[ 14 32 31 12 75]
[ 93 34 57 3 56]]
[[ 69 21 4 94 39]
[ 83 96 62 102 80]
[ 55 113 48 98 29]
[107 81 67 76 28]]
[[ 53 51 77 66 63]
[ 92 115 118 116 13]
[ 43 78 15 1 0]
[ 99 50 27 60 73]]]
[[[ 97 88 91 64 86]
[ 72 110 26 87 33]
[ 70 30 41 114 5]
[ 95 82 46 16 61]]
[[109 71 45 8 40]
[101 9 23 59 10]
[ 37 65 44 11 19]
[ 42 104 106 105 18]]
[[112 58 7 17 89]
[ 25 79 103 85 20]
[ 35 6 108 100 36]
[ 24 52 2 54 84]]]]
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]]
[[[ 49 117 12 3]
[ 94 102 98 76]
[ 66 116 1 60]]
[[ 97 72 70 95]
[109 101 37 42]
[112 25 35 24]]]