Tensorflow:时间流中索引的查找张量

时间:2018-12-25 04:29:41

标签: python tensorflow

我正在训练RNN,需要在其中使用索引在示例时间流的另一部分中查找值

v = tf.constant([
    [[.1, .2], [.3, .4]],  # timestream 1 values
    [[.6, .5], [.7, .8]]   # timestream 2 values
])
ixs = tf.constant([
    [1, 0], # indices into timestream 1 values
    [0, 1]  # indices into timestream 2 values
])

我正在寻找一个可以进行查找并用张量值和屈服值替换索引的操作:

[
    [[.3, .4], [.1, .2]],
    [[.6, .5], [.7, .8]]
]

tf.gather和tf.gather_nd听起来像是正确的道路,但我不太了解从中得到的结果。

v_at_ix = tf.gather(v, ixs, axis=-1)
sess.run(v_at_ix)
array([[[[0.2, 0.1],
         [0.1, 0.2]],
        [[0.4, 0.3],
         [0.3, 0.4]]],
       [[[0.5, 0.6],
         [0.6, 0.5]],
        [[0.8, 0.7],
         [0.7, 0.8]]]], dtype=float32)

v_at_ix = tf.gather_nd(v, ixs)
sess.run(v_at_ix)
array([[0.6, 0.5],
       [0.3, 0.4]], dtype=float32)

有人知道正确的方法吗?

1 个答案:

答案 0 :(得分:2)

tf.gather 只能基于指定的轴获取切片,并且其索引并置。在v_at_ix = tf.gather(v, ixs, axis=-1)中:

1中的

[1, 0]代表[[[.2],[.4]],[[.5],[.8]]]中的v

0中的

[1, 0]代表[[[.1],[.3]],[[.6],[.7]]]中的v

0中的

[0, 1]代表[[[.1],[.3]],[[.6],[.7]]]中的v

1中的

[0, 1]代表[[[.2],[.4]],[[.5],[.8]]]中的v

tf.gather_nd 能够获取指定索引处的切片,并且其索引是渐进式的。在v_at_ix = tf.gather_nd(v, ixs)中:

1中的

[1, 0]代表[[.6, .5], [.7, .8]]中的v

0中的

[1, 0]代表[.6, .5]中的[[.6, .5], [.7, .8]]

0中的

[0, 1]代表[[.1, .2], [.3, .4]]中的v

1中的

[0, 1]代表[.3, .4]中的[[.1, .2], [.3, .4]]

因此,当我们使用[[[0,1],[0,0]],[[1,0],[1,1]]]时,我们需要的是tf.gather_nd。它可以由[[0,0],[1,1]][[1,0],[0,1]]组成。前者是重复的行号,而后者是ixs。所以我们可以做到

ixs_row = tf.tile(tf.expand_dims(tf.range(v.shape[0]),-1),multiples=[1,v.shape[1]])
ixs = tf.concat([tf.expand_dims(ixs_row,-1),tf.expand_dims(ixs,-1)],axis=-1)
v_at_ix = tf.gather_nd(v,ixs)