用索引列表切片火炬张量

时间:2020-03-29 01:59:09

标签: python pytorch slice

我正在做一个强化学习项目,并且我试图获得一个表示所有给定动作的预期收益的张量。我选择了一个大小为batch的动作,其张量很长,其值为零或一个(两个潜在动作)。对于大小为batch * action_size的每个操作,我都有一个预期的张量,并且我想要一个大小为batch的张量。

例如,如果批处理大小为4,则我有

action = tensor([1,0,0,1])
expectedReward = tensor([[3,7],[5,9],[-1,12],[0,1]])

我想要的是

rewardForActions = tensor([7,5,-1,1])

我以为this会回答我的问题,但是根本不一样,因为如果我采用该解决方案,它将以4 * 4张量结束,从每行中选择4次,而不是一次。

有什么想法吗?

1 个答案:

答案 0 :(得分:2)

您可以

rewardForActions = expectedReward.index_select(1, action).diagonal()  
# tensor([ 7,  5, -1,  1])