如何使用TensorFlow张量索引列表?

时间:2016-12-16 15:03:41

标签: python list indexing tensorflow

假设一个包含不可连接对象的列表需要通过查找表进行访问。所以列表索引将是张量对象,但这是不可能的。

 tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4]))
 index = tf.constant(2)
 list = [0,1,2,3,4]

 target = list[tf_look_up[index]]

这将显示以下错误消息。

 TypeError: list indices must be integers or slices, not Tensor

是使用张量索引列表的方法/解决方法吗?

3 个答案:

答案 0 :(得分:11)

tf.gather就是为此而设计的。

只需运行tf.gather(list, tf_look_up[index]),即可获得所需内容。

答案 1 :(得分:2)

Tensorflow实际上支持HashTable。有关详细信息,请参阅documentation

在这里,你可以做的是:

table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, list), -1)

然后通过运行

获得所需的输入
target = table.lookup(index)

请注意,如果找不到密钥,-1是默认值。您可能必须根据张量的配置将key_dtypevalue_dtype添加到构造函数中。

答案 2 :(得分:0)

我认为这会有所帮助: How can I convert a tensor into a numpy array in TensorFlow?

"要从张量转换回numpy数组,您只需在转换的张量上运行.eval()。"