这就是我要完成的事情:
m = np.array([
[[1, 0, 0],
[1, 0, 0]],
[[0, 2, 0],
[0, 2, 0]],
[[0, 0, 3],
[0, 0, 3]]
])
i = np.arange(3)
print(m[i, :, i])
>> [[1 1]
[2 2]
[3 3]]
使用numpy
,一切正常。但是当我尝试用tf做同样的事情时,我遇到了一个错误:
ValueError: Shapes must be equal rank, but are 0 and 1
From merging shape 1 with other shapes. for 'strided_slice_1/stack_1' (op: 'Pack') with input shapes: [15], [], [15].
我发现了gather,但它仅适用于一维。使用gather_nd,我需要以某种方式构造整个蒙版。
看起来可以用gather_nd
来解决,但首先需要创建索引掩码:
[
[
[0, 0, 0],
[0, 1, 0],
],
[
[1, 0, 1],
[1, 1, 1],
],
[
[2, 0, 2],
[2, 1, 2],
],
]