Keras张量 - 获得指数来自另一个张量的值

时间:2017-10-02 13:52:47

标签: python-3.x tensorflow keras slice tensor

假设我有这两个张量:

  • valueMatrix,形状为(?, 3),其中?是批量大小
  • indexMatrix,形如(?, 1)

我想从valueMatrix中包含的索引中检索indexMatrix的值。

示例(伪代码):

valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float 
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int

我希望从这个例子中做出类似的事情:

valueMatrix[indexMatrix] --> returns --> [[15],[4]]

我更喜欢Tensorflow而不是其他后端,但答案必须与使用Lambda图层或其他合适图层的Keras模型兼容。

1 个答案:

答案 0 :(得分:3)

import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])

# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)

with tf.Session() as sess:
    print(sess.run(values))
#[[15]
# [ 4]]