我正在训练RNN,需要在其中使用索引在示例时间流的另一部分中查找值
v = tf.constant([
[[.1, .2], [.3, .4]], # timestream 1 values
[[.6, .5], [.7, .8]] # timestream 2 values
])
ixs = tf.constant([
[1, 0], # indices into timestream 1 values
[0, 1] # indices into timestream 2 values
])
我正在寻找一个可以进行查找并用张量值和屈服值替换索引的操作:
[
[[.3, .4], [.1, .2]],
[[.6, .5], [.7, .8]]
]
tf.gather和tf.gather_nd听起来像是正确的道路,但我不太了解从中得到的结果。
v_at_ix = tf.gather(v, ixs, axis=-1)
sess.run(v_at_ix)
array([[[[0.2, 0.1],
[0.1, 0.2]],
[[0.4, 0.3],
[0.3, 0.4]]],
[[[0.5, 0.6],
[0.6, 0.5]],
[[0.8, 0.7],
[0.7, 0.8]]]], dtype=float32)
v_at_ix = tf.gather_nd(v, ixs)
sess.run(v_at_ix)
array([[0.6, 0.5],
[0.3, 0.4]], dtype=float32)
有人知道正确的方法吗?
答案 0 :(得分:2)
tf.gather 只能基于指定的轴获取切片,并且其索引并置。在v_at_ix = tf.gather(v, ixs, axis=-1)
中:
1
中的 [1, 0]
代表[[[.2],[.4]],[[.5],[.8]]]
中的v
。
0
中的 [1, 0]
代表[[[.1],[.3]],[[.6],[.7]]]
中的v
。
0
中的 [0, 1]
代表[[[.1],[.3]],[[.6],[.7]]]
中的v
。
1
中的 [0, 1]
代表[[[.2],[.4]],[[.5],[.8]]]
中的v
。
tf.gather_nd 能够获取指定索引处的切片,并且其索引是渐进式的。在v_at_ix = tf.gather_nd(v, ixs)
中:
1
中的 [1, 0]
代表[[.6, .5], [.7, .8]]
中的v
。
0
中的 [1, 0]
代表[.6, .5]
中的[[.6, .5], [.7, .8]]
。
0
中的 [0, 1]
代表[[.1, .2], [.3, .4]]
中的v
。
1
中的 [0, 1]
代表[.3, .4]
中的[[.1, .2], [.3, .4]]
。
因此,当我们使用[[[0,1],[0,0]],[[1,0],[1,1]]]
时,我们需要的是tf.gather_nd
。它可以由[[0,0],[1,1]]
和[[1,0],[0,1]]
组成。前者是重复的行号,而后者是ixs
。所以我们可以做到
ixs_row = tf.tile(tf.expand_dims(tf.range(v.shape[0]),-1),multiples=[1,v.shape[1]])
ixs = tf.concat([tf.expand_dims(ixs_row,-1),tf.expand_dims(ixs,-1)],axis=-1)
v_at_ix = tf.gather_nd(v,ixs)