我在tf编码此操作时遇到了一些麻烦。这是一个例子让我们假设我有一个[n,2]占位符x和一个[n,1]占位符y。 x = [[1,2],[3,4],[5,6]] Y = [1,0,1] 对于来自y的每个元素,我想从第i个2d张量中取出相应的元素。 在该示例中,输出应为[2,3,6]。我尝试了几种技术但没有成功。使用tensorflow有一种简单的方法吗?
谢谢
答案 0 :(得分:0)
使用tf.gather_nd
或tf.stack
和tf.where
手动黑客攻击:
import tensorflow as tf
x = tf.convert_to_tensor([[1, 2], [3, 4], [5, 6]])
y = tf.convert_to_tensor([1, 0, 1])
with tf.Session() as sess:
xx = tf.unstack(x, axis=1)
ans = tf.where(tf.equal(y, tf.zeros_like(y)), xx[0], xx[1])
print sess.run(ans)
with tf.Session() as sess:
idx = tf.range(0, limit=3, delta=1, name='arange')
idx = tf.stack([idx, y], axis=-1)
ans = tf.gather_nd(x, idx)
print sess.run(ans)