为2D张量创建索引

时间:2017-09-16 21:05:40

标签: tensorflow

我试图做这样的事情:

假设输入Tensor是(2, 3) Tensor,其值如:

[[1 2]
 [3 4]
 [5 6]]

我通过执行以下操作将2D Tensor展平为1D:

input = tf.reshape(input, [-1])

所以现在输入变为[1 2 3 4 5 6]

但是我也创建了一个1D Tensor来索引前一个Tensor,所以期望的输出是[0 0 1 1 2 2]。我应该如何在TF中创建这个Tensor?

通常,如果输入Tensor的形状为(X, Y)。我想创建一个看起来像这样的1D Tensor:

[0 0 0 ... 0 1 ....1 2 ... 2 ... X-1 ... X-1]

其中每个值重复Y - 1次。

1 个答案:

答案 0 :(得分:3)

这是一种方法;这样做首先创建一个单行索引,然后按原始张量使用tf.tile的列数重复每个指数;重新塑造2d指数可以满足您的需求。

t = tf.constant([[1,2],[3,4],[5,6]])

X, Y = t.shape
idx = tf.range(X.value)
idx_2d = tf.reshape(idx, [-1,1])
idx_2d_full = tf.tile(idx_2d, [1, Y.value])
idx_flat = tf.reshape(idx_2d_full, [-1])

with tf.Session() as sess:
    print(sess.run(idx_flat))

[0 0 1 1 2 2]