通过索引跳过张量流中的一维获取元素

时间:2018-08-02 15:31:18

标签: python numpy tensorflow

这就是我要完成的事情:

m = np.array([
    [[1, 0, 0],
     [1, 0, 0]],
    [[0, 2, 0],
     [0, 2, 0]],
    [[0, 0, 3],
     [0, 0, 3]]
])
i = np.arange(3)
print(m[i, :, i])

>> [[1 1]
 [2 2]
 [3 3]]

使用numpy,一切正常。但是当我尝试用tf做同样的事情时,我遇到了一个错误:

ValueError: Shapes must be equal rank, but are 0 and 1
    From merging shape 1 with other shapes. for 'strided_slice_1/stack_1' (op: 'Pack') with input shapes: [15], [], [15].

我发现了gather,但它仅适用于一维。使用gather_nd,我需要以某种方式构造整个蒙版。

编辑:

看起来可以用gather_nd来解决,但首先需要创建索引掩码:

[
    [
        [0, 0, 0],
        [0, 1, 0],
    ],
    [
        [1, 0, 1],
        [1, 1, 1],
    ],
    [
        [2, 0, 2],
        [2, 1, 2],
    ],
]

0 个答案:

没有答案