匹配PyTorch张量尺寸

时间:2020-09-25 15:35:30

标签: python machine-learning pytorch mnist

目前,在我的训练功能中,关于张量的维数存在一些问题。我正在使用MNIST数据集,因此有10个可能的目标,并且最初使用10个训练批处理量编写了原型代码,回想起来这不是最明智的选择。在某些较早的测试中,结果不佳,增加训练迭代次数没有任何好处。在尝试增加批处理大小之后,我意识到我所写的内容并不那么普遍,而且我很可能从未对正确的数据进行过培训。下面是我的训练功能:

def Train(tLoops, Lrate):
    for _ in range(tLoops):
        tempData = train_data.view(batch_size_train, 1, 1, -1)
        output = net(tempData)
        trainTarget = train_targets
        criterion = nn.MSELoss()
        print("target:", trainTarget.size())
        print("Output:", output.size())
        loss = criterion(output, trainTarget.float())
        # print("error is:", loss)
        net.zero_grad()  # zeroes the gradient buffers of all parameters
        loss.backward()
        for j in net.parameters():
            j.data.sub_(j.grad.data * Lrate)

打印功能返回的

target: torch.Size([100])
Output: torch.Size([100, 1, 1, 10])

在计算损失的行上的错误消息之前;

RuntimeError: The size of tensor a (10) must match the size of tensor b (100) at non-singleton dimension 3

第一个打印目标是每个图像的相应地面真实值的一维列表。输出包含这100个样本中每个样本的神经网络输出,因此有一个10 x 100的列表,但是从先前将数据从28 x 28略读和重塑为1 x 784来看,我似乎不必要地拥有额外的维度。 PyTorch是否提供删除这些内容的方法?我在文档中找不到任何内容,或者还有其他可能是我的问题吗?

1 个答案:

答案 0 :(得分:1)

您的训练脚本中有几个问题。我将在下面对它们中的每一个进行处理。

  1. 首先,您不应该手工进行数据批处理。 Pytorch / torchvision具有此功能,使用数据集和数据加载器:https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html

  2. 也不要手动更新网络参数。使用优化程序:https://pytorch.org/docs/stable/optim.html。在您的情况下,没有动力的SGD将具有相同的效果。

  3. 您输入的维数似乎是错误的,对于MNIST,如果要训练MLP,则输入张量应该为(batch_size,1,28,28)或(batch_size,784)。此外,您网络的输出应为(batch_size,10)