给定一个类实例列表,我需要使用tf.tensor对其进行索引。例如:
Class Something():
def __init__(self):
self.a = 1
self.b = 2
list = [Something() for a in range(0, 10)]
index_queue = tf.train.range_input_producer(len(list))
index = index_queue.dequeue()
result = list[index]
tensor = function_that_returns_tensor(result)
with tf.Session() as sess:
sess.run(tensor)
上面的代码会出现以下错误:TypeError: list indices must be integers, not Tensor
使用tf.gather(list, index)
会出现以下错误:
TypeError: Expected binary or unicode string, got <__main__.Something object at 0x7f4529fae2b0>
任何帮助都将受到高度赞赏。谢谢!
答案 0 :(得分:0)
问题在于TensorFlow如何工作的核心机制。当您调用诸如tf.train.range_input_producer(len(list))
或tf.constant
之类的TensorFlow方法时,您实际上并未正在运行这些操作。您只是将这些操作添加到TensorFlow计算图中。然后,您必须使用run
实例的tf.Session
方法来运行这些操作并从中获取结果。 TypeError: list indices must be integers, not Tensor
告诉您,您将计算图上的张量作为索引传递给张量,而不是运行产生张量的操作返回的结果。
有关更详细的说明,请参阅this TensorFlow documentation。