如何从Tensorflow中的张量中获取特定行?

时间:2016-08-03 12:22:07

标签: python tensorflow

我的张量定义如下:

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]]))

我还有一个从张量中提取的行索引数组:

idx = tf.constant([0, 2])

现在,我想在这些索引中使用temp_var的子集,即idx

我知道要采用单个索引或切片,我们可以执行类似

的操作
temp_var[single_row_index, :]

temp_var[start:end, :]

但是如何获取idx数组指示的行? 像temp_var[idx, :]

这样的东西

1 个答案:

答案 0 :(得分:10)

tf.gather() op完全符合您的需要:它从矩阵中选择行(或者从N维张量中选择一般(N-1)维的切片)。以下是它在您的情况下的工作方式:

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]))
idx = tf.constant([0, 2])

rows = tf.gather(temp_var, idx)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

print(sess.run(rows))  # ==> [[1, 2, 3], [7, 8, 9]]