深度Q学习:torch.nn.functional.softmax崩溃

时间:2018-12-16 09:13:08

标签: python pytorch softmax

我正在学习一个教程,使用它时softmax函数会崩溃。

newSignals = [0.5, 0., 0., -0.7911, 0.7911]
newState = torch.Tensor(newSignals).float().unsqueeze(0)
probs = F.softmax(self.model(newState), dim=1)

self.model是一个神经网络(torch.nn.module),返回的张量类似于

tensor([[ 0.2699, -0.2176, 0.0333]], grad_fn=<AddmmBackward>)

因此,行probs = F.softmax(self.model(newState), dim=1)使程序崩溃,但是当dim=0起作用时,它就不好了。

1 个答案:

答案 0 :(得分:0)

免责声明:很抱歉,这本来应该是评论,但我不能在评论中写下所有内容。

您确定这是问题所在吗?下面的代码段仅对我有用。

import torch
a = torch.tensor([[ 0.2699, -0.2176,  0.0333]]) 
a.softmax(dim=1)
> tensor([[0.4161, 0.2555, 0.3284]])
a.softmax(dim=0)
> tensor([[1., 1., 1.]])