用张量切片张量流量张量

时间:2017-06-28 03:43:58

标签: tensorflow

我正在尝试使用此PR中添加的“高级”,numpy风格的切片,但是我遇到了same issue as the user here

ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_15' (op: 'StridedSlice') with input shapes: [3,2], [1,2], [1,2], [1].

即我想做相当于这个numpy操作(在numpy中工作):

A = np.array([[1,2],[3,4],[5,6]]) 
id_rows = np.array([0,2])
A[id_rows]

然而,对于上述错误,这在TF中不起作用:

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([0,2])
A[id_rows]

2 个答案:

答案 0 :(得分:1)

你正在寻找这样的东西:

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([[0],[2]]) #Notice the brackets
out = tf.gather_nd(A,id_rows)

答案 1 :(得分:1)

您可以按如下方式对张量进行切片。

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant(np.array([0, 2]).reshape(-1, 1))
out = tf.gather_nd(A,id_rows)
with tf.Session() as session: 
    print(session.run(out))