TensorFlow:根据索引列表

时间:2017-04-01 09:26:25

标签: python tensorflow

我现在正在学习张量流,我有一个我无法通过谷歌找到的问题。我发现在遇到问题时更容易处理问题和查找文档,所以如果这是文档中的某个地方并且我没有应用,我会道歉。

我有一个张量。让我们说它是100 x 1.我们称之为 t1 。我还有一个整数列表,范围为[0,99],大小为5000,称之为 l 。我想将 t1 转换为5000 x 1张量,称之为 t2

关系如下:假设 l 的第i个条目是j。然后,我希望 t2 的第i个条目等于 t1 的第j个条目。

现在,如果这些是numpy数组,我只会这样做:

    t2 = t1[l]

但是我不认为这是在张量流中做到这一点的有效方式,而且它似乎甚至无法工作。

建议?

1 个答案:

答案 0 :(得分:1)

您正在寻找的是tf.gather: https://www.tensorflow.org/api_docs/python/tf/gather

import tensorflow as tf
tf.InteractiveSession()
t1 = tf.random_normal((100, 1))
l = tf.random_uniform((5000, ), minval=0, maxval=99, dtype=tf.int32)
t2 = tf.gather(t1, l)