softmax应该应用于哪个维度?
此代码:
%reset -f
import torch.nn as nn
import numpy as np
import torch
my_softmax = nn.Softmax(dim=-1)
mu, sigma = 0, 0.1 # mean and standard deviation
train_dataset = []
image = []
image_x = np.random.normal(mu, sigma, 24).reshape((3 , 4, 2))
train_dataset.append(image_x)
x = torch.tensor(train_dataset).float()
print(x)
print(my_softmax(x))
my_softmax = nn.Softmax(dim=1)
print(my_softmax(x))
打印以下内容:
tensor([[[[-0.1500, 0.0243],
[ 0.0226, 0.0772],
[-0.0180, -0.0278],
[ 0.0782, -0.0853]],
[[-0.0134, -0.1139],
[ 0.0385, -0.1367],
[-0.0447, 0.1493],
[-0.0633, -0.2964]],
[[ 0.0123, 0.0061],
[ 0.1086, -0.0049],
[-0.0918, -0.1308],
[-0.0100, 0.1730]]]])
tensor([[[[ 0.4565, 0.5435],
[ 0.4864, 0.5136],
[ 0.5025, 0.4975],
[ 0.5408, 0.4592]],
[[ 0.5251, 0.4749],
[ 0.5437, 0.4563],
[ 0.4517, 0.5483],
[ 0.5580, 0.4420]],
[[ 0.5016, 0.4984],
[ 0.5284, 0.4716],
[ 0.5098, 0.4902],
[ 0.4544, 0.5456]]]])
tensor([[[[ 0.3010, 0.3505],
[ 0.3220, 0.3665],
[ 0.3445, 0.3230],
[ 0.3592, 0.3221]],
[[ 0.3450, 0.3053],
[ 0.3271, 0.2959],
[ 0.3355, 0.3856],
[ 0.3118, 0.2608]],
[[ 0.3540, 0.3442],
[ 0.3509, 0.3376],
[ 0.3200, 0.2914],
[ 0.3289, 0.4171]]]])
因此,第一个张量在施加softmax之前,第二个张量是对soft施加dim = -1的张量的结果,而第三个张量是softmax施加给dim = 1的张量的结果。
对于第一个softmax的结果,可以看到对应元素之和为1,例如[0.4565,0.5435]-> 0.4565 + 0.5435 == 1。
第二个softmax结果等于1是什么?
我应该选择哪个暗值?
更新:尺寸(3 , 4, 2)
对应于图像尺寸,其中3是RGB值,4是水平像素数(宽度),2是垂直像素数(高度)。这是图像分类问题。我正在使用交叉熵损失函数。另外,我在最后一层中使用softmax以便向后传播概率。
答案 0 :(得分:3)
您有一个1x3x4x2张量train_dataset。您的softmax函数的dim参数确定执行Softmax操作的尺寸。第一维是批次尺寸,第二维是深度,第三维是行,最后一个是列。请查看下面的图片(对不起,这幅画很糟),以了解将dim设置为1时如何执行softmax。
简而言之,您的4x2矩阵的每个对应条目的总和等于1。
更新:应将softmax应用于哪个维度的问题取决于张量存储的数据以及目标是什么。
更新:有关图像分类的任务,请参见pytorch官方网站上的tutorial。它涵盖了在实际数据集上使用pytorch进行图像分类的基础及其简短的教程。尽管该教程不执行Softmax操作,但是您只需要在最后一个完全连接的层的输出上使用torch.nn.functional.log_softmax。有关完整的示例,请参见MNIST classifier with pytorch。将图像展平为完全连接的图层后,图像是RGB还是灰度都没有关系(还要记住,MNIST示例的相同代码可能不适合您,取决于您使用的pytorch版本)。
答案 1 :(得分:0)
对于大多数深度学习问题,我们肯定会分批提出。因此dim始终为1。不要与它混淆。通过我们所说的函数,您可以沿每个批处理的内容进行操作(此处为向量,即如果您有8个类,则其中有8个元素每一行)。您也可以提及dim = -1。