TensorFlow:使用张量索引另一个不同形状的张量列表

时间:2019-01-21 07:40:12

标签: python tensorflow

如何使用张量访问张量列表索引?当列表中的张量具有相同形状时,tf.gather起作用。例如以下作品

tensor_list = [tf.constant([1,2,3]), tf.constant([4,5,6])]
index = tf.placeholder(tf.int32, None)
out = tf.gather(tensor_list, index, axis=0)
sess = tf.Session()
print(sess.run(out, feed_dict={index: 1}))

[4 5 6]

但是,当列表包含不同形状的张量时,我会遇到ValueError

tensor_list = [tf.constant([1,2,3]), tf.constant([4,5])]
index = tf.placeholder(tf.int32, None)
out = tf.gather(tensor_list, index, axis=0)
sess = tf.Session()
print(sess.run(out, feed_dict={index: 1}))

ValueError: Tried to convert 'params' to a tensor and failed. Error: Dimension 0 in both shapes must be equal, but are 3 and 2. Shapes are [3] and [2].
    From merging shape 0 with other shapes. for 'GatherV2/packed' (op: 'Pack') with input shapes: [3], [2].

我的图包含存储在tensor_list中的多个输出,我想访问特定的输出以根据运行时提供的索引来计算损失函数。

0 个答案:

没有答案