如何使用索引迭代一维张量?

时间:2019-06-25 14:07:27

标签: python tensorflow

我想遍历张量并获取每个元素的索引。

例如...

tensor = tf.constant([1,2,3])

for idx, elem in enumerate(tensor):
    print(idx, elem)

所需的输出:

0 1
1 2
2 3

2 个答案:

答案 0 :(得分:0)

如果您需要将一维张量与索引配对,请使用tf.stacktf.range(兼容TF 1.x和2.0):

tf.stack([tf.range(tf.shape(tensor)[0]), tensor], axis=1)
# <tf.Tensor 'stack:0' shape=(3, 2) dtype=int32>

无论您需要做什么,都可以不必实际遍历男高音。

答案 1 :(得分:0)

启用急切执行

import tensorflow as tf
tf.enable_eager_execution()
tensor = tf.constant([1,2,3])
for idx, elem in enumerate(tensor):
    tf.print(idx, elem)

0 1
1 2
2 3