给定一个索引数组index
,比如一个矩阵A
我希望矩阵B
具有A
列的相应排列。
在Numpy,我会做以下事情,
>>> A = np.arange(6).reshape(2,3); A
array([[0, 1, 2],
[3, 4, 5]])
>>> index = [2,0,1]
>>> A[:,index]
array([[2, 0, 1],
[5, 3, 4]])
在MXNet中有 自然 或 高效 方式吗?函数pick()
和take()
似乎不会以这种方式工作。我设法提出以下但是它并不优雅。
>>> mx.nd.take(A.T, mx.nd.array([[2],[0],[1]])).T.reshape((2,3))
[[ 2. 0. 1.]
[ 5. 3. 4.]]
<NDArray 2x3 @cpu(0)>
最后,为了解决这个问题,有没有办法在这里进行这项工作?
更新这是一个稍微优雅,但可能不那么有效(由于换位),上面的版本:
>>> mx.nd.take(A.T, mx.nd.array([2,0,1])).T
[[ 2. 0. 1.]
[ 5. 3. 4.]]
<NDArray 2x3 @cpu(0)>
答案 0 :(得分:2)
您需要的是MXNet中所谓的高级索引。提交了一个PR,用于通过MXNet NDArray的高级索引获取元素,并且还将设置元素的功能添加到NDArray。预计将在1.0版本中发布。