如何在tensorflow中使用索引收集元素

时间:2018-08-04 22:46:29

标签: python tensorflow tensor

例如,

import tensorflow as tf

index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])

如果我使用extract = tf.gather_nd(values, index) 返回是

[[0.4 0.6]
 [0.4 0.6]]

但是,我希望结果是

[[0.8], [0.6]]

其中索引沿轴= 1,但是,在tf.gather_nd中没有轴参数设置。

我该怎么办?谢谢!

1 个答案:

答案 0 :(得分:1)

将范围连接到def sumOfNumber(number): sum = 0 while(number >= 1): temp = int(number % 10) sum += temp number = int(number / 10) return sum main(): sumOfNumber(123) # 6

index

index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)