如何使用TensorFlow张量索引类实例列表

时间:2017-08-08 15:22:20

标签: python tensorflow

给定一个类实例列表,我需要使用tf.tensor对其进行索引。例如:

Class Something(): 
     def __init__(self):  
         self.a = 1
         self.b = 2

list = [Something() for a in range(0, 10)]
index_queue = tf.train.range_input_producer(len(list))
index = index_queue.dequeue()
result = list[index]
tensor = function_that_returns_tensor(result)
with tf.Session() as sess:
     sess.run(tensor)

上面的代码会出现以下错误:TypeError: list indices must be integers, not Tensor

使用tf.gather(list, index)会出现以下错误:

TypeError: Expected binary or unicode string, got  <__main__.Something object at 0x7f4529fae2b0> 

任何帮助都将受到高度赞赏。谢谢!

1 个答案:

答案 0 :(得分:0)

问题在于TensorFlow如何工作的核心机制。当您调用诸如tf.train.range_input_producer(len(list))tf.constant之类的TensorFlow方法时,您实际上并未正在运行这些操作。您只是将这些操作添加到TensorFlow计算图中。然后,您必须使用run实例的tf.Session方法来运行这些操作并从中获取结果。 TypeError: list indices must be integers, not Tensor告诉您,您将计算图上的张量作为索引传递给张量,而不是运行产生张量的操作返回的结果。

有关更详细的说明,请参阅this TensorFlow documentation