我想在TensorFlow张量上执行以下索引操作。
TensorFlow中等效的操作是什么,以获取b
和c
作为输出?尽管tf.gather_nd
文档中有几个示例,但我无法生成等效的indices
张量来获得这些结果。
import tensorflow as tf
import numpy as np
a=np.arange(18).reshape((2,3,3))
idx=[2,0,1] #it can be any validing re-ordering index list
#These are the two numpy operations that I want to do in Tensorflow
b=a[:,idx,:]
c=a[:,:,idx]
# TensorFlow operations
aT=tf.constant(a)
idxT=tf.constant(idx)
# what should be these two indices
idx1T=tf.reshape(idxT, (3,1))
idx2T=tf.reshape(idxT, (1,1,3))
bT=tf.gather_nd(aT, idx1T ) #does not work
cT=tf.gather_nd(aT, idx2T) #does not work
with tf.Session() as sess:
b1,c1=sess.run([bT,cT])
print(np.allclose(b,b1))
print(np.allclose(c,c1))
我不局限于tf.gather_nd
,任何在GPU上实现相同操作的建议也将有所帮助。
旧语句:c=a[:,idx]
,
新声明:c=a[:,:,idx]
我想要实现的也是对列进行重新排序。
答案 0 :(得分:1)
可以使用axis
参数通过tf.gather
完成
import tensorflow as tf
import numpy as np
a = np.arange(18).reshape((2,3,3))
idx = [2,0,1]
b = a[:, idx, :]
c = a[:, :, idx]
aT = tf.constant(a)
idxT = tf.constant(idx)
bT = tf.gather(aT, idxT, axis=1)
cT = tf.gather(aT, idxT, axis=2)
with tf.Session() as sess:
b1, c1=sess.run([bT, cT])
print(np.allclose(b, b1))
print(np.allclose(c, c1))
输出:
True
True