如何在Tensorflow中从3D数组中提取行和列

时间:2018-10-18 16:11:39

标签: python-3.x numpy tensorflow

我想在TensorFlow张量上执行以下索引操作。 TensorFlow中等效的操作是什么,以获取bc作为输出?尽管tf.gather_nd文档中有几个示例,但我无法生成等效的indices张量来获得这些结果。

import tensorflow as tf
import numpy as np

a=np.arange(18).reshape((2,3,3))

idx=[2,0,1] #it can be any validing re-ordering index list

#These are the two numpy operations that I want to do in Tensorflow
b=a[:,idx,:]
c=a[:,:,idx] 

# TensorFlow operations

aT=tf.constant(a)
idxT=tf.constant(idx)

# what should be these two indices  
idx1T=tf.reshape(idxT, (3,1)) 
idx2T=tf.reshape(idxT, (1,1,3))

bT=tf.gather_nd(aT, idx1T ) #does not work
cT=tf.gather_nd(aT, idx2T)  #does not work

with tf.Session() as sess:
    b1,c1=sess.run([bT,cT])

print(np.allclose(b,b1))
print(np.allclose(c,c1))

我不局限于tf.gather_nd,任何在GPU上实现相同操作的建议也将有所帮助。

编辑:我已经更新了打字错误的问题:

旧语句:c=a[:,idx]

新声明:c=a[:,:,idx] 我想要实现的也是对列进行重新排序。

1 个答案:

答案 0 :(得分:1)

可以使用axis参数通过tf.gather完成

import tensorflow as tf
import numpy as np

a = np.arange(18).reshape((2,3,3))
idx = [2,0,1]
b = a[:, idx, :]
c = a[:, :, idx]

aT = tf.constant(a)
idxT = tf.constant(idx)
bT = tf.gather(aT, idxT, axis=1)
cT = tf.gather(aT, idxT, axis=2)

with tf.Session() as sess:
    b1, c1=sess.run([bT, cT])

print(np.allclose(b, b1))
print(np.allclose(c, c1))

输出:

True
True