Tensorflow中的复杂切片操作

时间:2018-09-11 07:00:02

标签: python tensorflow

我对Tensorflow的切片操作感到困惑。我想做的就是在Numpy中这样,

>>> a = np.arange(24).reshape((4,6))
>>> a
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])
>>> print(a[[2,3],[0,1]])
array([12, 19])

但是在Tensorflow中,

>>> a = tf.Variable(np.arange(24).reshape((4,6)))
>>> with tf.Session() as sess:
...  sess.run(tf.global_variables_initializer())
...  print(sess.run(a[[2,3],[0,1]]))

我遇到一个错误,说TypeError: can only concatenate list (not "int") to list。有没有办法在Tensorflow中执行此切片?

谢谢。

1 个答案:

答案 0 :(得分:1)

这是一种方式。但是我已经重新组织了索引([2,0],[3,1])。

a = tf.Variable(np.arange(24).reshape((4, 6)))

sess.run(tf.global_variables_initializer())

print(sess.run(tf.gather_nd(a, [[2,0],[3,1]])))

输出为

  

[12 19]