我是Pytorch的新手,我正在尝试实现一个简单的CNN以识别MNIST图像。
我正在使用MSE损失作为损失函数并使用SGD作为优化器来训练网络。当我开始培训时,它会给我以下内容
警告:“用户警告:使用与输入大小(torch.Size([64,10])不同的目标大小(torch.Size([64]))。这可能会导致错误的结果广播。请确保它们具有相同的尺寸。”
然后我得到以下
error: "RuntimeError: The size of tensor a (10) must match the size of tensor b
(64) at non-singleton dimension 1".
我尝试使用在其他问题中找到的一些解决方案来解决它,但似乎没有任何效果。这是我如何加载数据集的代码:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train = True, transform = transform, download = True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 64, shuffle = True)
testset = torchvision.datasets.MNIST(root='./data', train = False, transform = transform, download = True)
testloader = torch.utils.data.DataLoader(testset, batch_size = 64, shuffle = False)
定义我的网络的代码:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
#Convolutional layers
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 12, 5)
#Fully connected layers
self.fc1 = nn.Linear(12*4*4, 120)
self.fc2 = nn.Linear(120, 60)
self.out = nn.Linear(60,10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
x = x.reshape(-1, 12*4*4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.out(x)
return x
这是培训:
net = Net()
print(net)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)
epochs = 3
for epoch in range(epochs):
running_loss = 0;
for images, labels in trainloader:
optimizer.zero_grad()
output = net(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
else:
print(f"Training loss: {running_loss/len(trainloader)}")
print('Finished training')
谢谢!
答案 0 :(得分:2)
您正在使用的丢失(nn.MSELoss
)对于此问题是不正确的。您应该使用use nn.CrossEntropyLoss
。
均方差测量输入x和目标y之间的均方误差。这里,输入和目标自然应具有相同的形状。
交叉熵损失计算每个图像在类上的概率。输出将是矩阵N x C,目标将是大小N的向量。(N =批处理大小,C =班数)
由于您的目的是对图像进行分类,所以这是您要使用的。
在您的情况下,您的网络输出将是尺寸为64 x 10的矩阵,目标是尺寸为64的向量。输出矩阵的每一行(应用softmax函数后,表明该类别的概率),在此之后,计算交叉熵损失。 Pytorch的{{1}}结合了softmax操作和损失计算。
有关Pytorch如何计算损失的更多信息,请参阅文档here。