从张量流中的张量提取值

时间:2018-11-02 12:25:01

标签: tensorflow

如果有两个张量矩阵

a = [[1 2 3 4][5 6 7 8]]
b = [[0 1][1 2]],

我们怎么得到这个:

c = [[1 2][6 7]]

即从第一行提取列0和1开始,从第二行提取列1和2开始。

1 个答案:

答案 0 :(得分:0)

这是一种实现方法:

import tensorflow as tf

a = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
b = tf.constant([[0, 1],
                 [1, 2]])
row = tf.range(tf.shape(a)[0])
row = tf.tile(row[:, tf.newaxis], (1, tf.shape(b)[1]))
idx = tf.stack([row, b], axis=-1)
c = tf.gather_nd(a, idx)
with tf.Session() as sess:
    print(sess.run(c))

输出:

[[1 2]
 [6 7]]