如何使用张量访问张量列表索引?当列表中的张量具有相同形状时,tf.gather起作用。例如以下作品
tensor_list = [tf.constant([1,2,3]), tf.constant([4,5,6])]
index = tf.placeholder(tf.int32, None)
out = tf.gather(tensor_list, index, axis=0)
sess = tf.Session()
print(sess.run(out, feed_dict={index: 1}))
[4 5 6]
但是,当列表包含不同形状的张量时,我会遇到ValueError
tensor_list = [tf.constant([1,2,3]), tf.constant([4,5])]
index = tf.placeholder(tf.int32, None)
out = tf.gather(tensor_list, index, axis=0)
sess = tf.Session()
print(sess.run(out, feed_dict={index: 1}))
ValueError: Tried to convert 'params' to a tensor and failed. Error: Dimension 0 in both shapes must be equal, but are 3 and 2. Shapes are [3] and [2].
From merging shape 0 with other shapes. for 'GatherV2/packed' (op: 'Pack') with input shapes: [3], [2].
我的图包含存储在tensor_list中的多个输出,我想访问特定的输出以根据运行时提供的索引来计算损失函数。