Tensorflow根据另一个占位符从占位符获取元素

时间:2017-09-17 18:33:28

标签: python tensorflow deep-learning

我在tf编码此操作时遇到了一些麻烦。这是一个例子让我们假设我有一个[n,2]占位符x和一个[n,1]占位符y。 x = [[1,2],[3,4],[5,6]] Y = [1,0,1] 对于来自y的每个元素,我想从第i个2d张量中取出相应的元素。 在该示例中,输出应为[2,3,6]。我尝试了几种技术但没有成功。使用tensorflow有一种简单的方法吗?

谢谢

1 个答案:

答案 0 :(得分:0)

使用tf.gather_ndtf.stacktf.where手动黑客攻击:

import tensorflow as tf

x = tf.convert_to_tensor([[1, 2], [3, 4], [5, 6]])
y = tf.convert_to_tensor([1, 0, 1])

with tf.Session() as sess:
    xx = tf.unstack(x, axis=1)
    ans = tf.where(tf.equal(y, tf.zeros_like(y)), xx[0], xx[1])
    print sess.run(ans)


with tf.Session() as sess:
    idx = tf.range(0, limit=3, delta=1, name='arange')
    idx = tf.stack([idx, y], axis=-1)
    ans = tf.gather_nd(x, idx)
    print sess.run(ans)