如何使用给定索引索引numpy数组?

时间:2019-11-15 04:01:09

标签: python numpy

问题:如何使用给定索引对numpy数组进行索引?

说明

在强化学习中,我得到了许多与不同状态相对应的离散分布,如下所示:

import numpy as np
distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]])

# array([[0.1, 0.2, 0.7],  # \pi(s0)
#        [0.3, 0.3, 0.4],  # \pi(s1)
#        [0.2, 0.2, 0.6]]) # \pi(s2)

然后,我想分别获得在状态s0下采取行动0,在状态s1采取行动2和在状态s2采取行动1的概率。

所以我将索引值存储在如下数组中:

actions = np.array([[0],[2],[1]])

# array([[0],  # taking action 0 in state s0
#        [2],  # taking action 2 in state s1
#        [1]]) # taking action 1 in state s2

我希望得到的东西。

我想使用distributionsactions编制索引,并希望得到如下结果:

# array([0.1,0.4,0.2])
# or 
# array([[0.1],
#        [0.4],
#        [0.2]])

我尝试过的。

我已经尝试过np.take(distributions, actions),但是重新调整array([0.1, 0.7, 0.2])显然是我想要的。 而且我还尝试了distributions[:,actions],这给了我另一个错误的答案,如下:

array([[0.1, 0.7, 0.2],
       [0.3, 0.4, 0.3],
       [0.2, 0.6, 0.2]])         

问题

我该怎么解决这个问题?

1 个答案:

答案 0 :(得分:3)

In [614]: distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]]) 
     ...:                                                                       
In [615]: actions = np.array([[0],[2],[1]])  

使用[0,1,2]行索引:

In [616]: distributions[np.arange(3), actions]                                  
Out[616]: 
array([[0.1, 0.3, 0.2],
       [0.7, 0.4, 0.6],
       [0.2, 0.3, 0.2]])

哎呀,actions是(3,1)形状,它与(3,)一起广播以产生(3,3)选择。相反,我们想使用(3,)形的actions

In [617]: distributions[np.arange(3), actions.ravel()]                          
Out[617]: array([0.1, 0.4, 0.2])

或获得(3,1)结果:

In [619]: distributions[[[0],[1],[2]], actions]                                 
Out[619]: 
array([[0.1],
       [0.4],
       [0.2]])