我的神经网络末端有一个softmax功能。
我想拥有torch.tensor的概率。为此,我正在使用 torch.tensor(nn.softmax(x))
并收到错误RuntimeError: Could not infer dtype of Softmax
。
请让我知道我在这里做错了还是有其他解决方法。
答案 0 :(得分:1)
nn.Softmax
是一门课。您可以像这样使用它:
import torch
x = torch.tensor([10., 3., 8.])
softmax = torch.nn.Softmax(dim=0)
probs = softmax(x)
或者,您可以使用功能性API torch.nn.functional.softmax
:
import torch
x = torch.tensor([10., 3., 8.])
probs = torch.nn.functional.softmax(x, dim=0)
它们是等效的。在这两种情况下,您都可以检查type(probs)
是<class 'torch.Tensor'>
。