将nn.Softmax转换为torch.tensor

时间:2020-07-30 13:18:44

标签: python pytorch

我的神经网络末端有一个softmax功能。 我想拥有torch.tensor的概率。为此,我正在使用 torch.tensor(nn.softmax(x))并收到错误RuntimeError: Could not infer dtype of Softmax

请让我知道我在这里做错了还是有其他解决方法。

1 个答案:

答案 0 :(得分:1)

nn.Softmax是一门课。您可以像这样使用它:

import torch

x = torch.tensor([10., 3., 8.])
softmax = torch.nn.Softmax(dim=0)

probs = softmax(x)

或者,您可以使用功能性API torch.nn.functional.softmax

import torch

x = torch.tensor([10., 3., 8.])

probs = torch.nn.functional.softmax(x, dim=0)

它们是等效的。在这两种情况下,您都可以检查type(probs)<class 'torch.Tensor'>