Error when training logistic regression with the Adam optimizer on FashionMNIST

Posted: 2020-10-28 17:53:07

Tags: python pytorch logistic-regression adam

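The dataset is FashionMNIST (784 inputs, 10 outputs). I am trying to train a logistic regression model with the Adam optimizer (which I also implemented myself):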

import math

import torch
from torch import nn

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()

bias = torch.zeros(10, requires_grad=True)

optimizer = Adam([weights, bias])  # Adam is my own implementation (not shown here)
criterion = nn.CrossEntropyLoss()
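The asker's own Adam class is not shown. For reference, a minimal hand-written optimizer with the same interface (a constructor taking a list of tensors, plus zero_grad() and step()) might look like the sketch below. It follows the standard Adam update with bias correction; the class body, default hyperparameters, and attribute names are assumptions, not the asker's code.

```python
import torch


class Adam:
    """Hypothetical minimal Adam for a list of leaf tensors (not the asker's class)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.t = 0  # number of steps taken so far
        # First- and second-moment estimates, one buffer per parameter
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def zero_grad(self):
        # Reset accumulated gradients before the next backward pass
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

    @torch.no_grad()
    def step(self):
        self.t += 1
        for p, m, v in zip(self.params, self.m, self.v):
            if p.grad is None:
                continue
            g = p.grad
            # Exponential moving averages of the gradient and squared gradient
            m.mul_(self.beta1).add_(g, alpha=1 - self.beta1)
            v.mul_(self.beta2).addcmul_(g, g, value=1 - self.beta2)
            # Bias-corrected moment estimates
            m_hat = m / (1 - self.beta1 ** self.t)
            v_hat = v / (1 - self.beta2 ** self.t)
            # In-place parameter update (allowed because we are under no_grad)
            p.sub_(self.lr * m_hat / (v_hat.sqrt() + self.eps))
```

On the very first step the bias correction makes the update roughly lr in magnitude, regardless of the gradient's scale, which is one of Adam's characteristic properties.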

The training function is:

def train_logistic_regression(weights, bias, batch, loss, optimizer):

    inputs, labels = batch

    inputs = inputs.view(inputs.shape[0], -1)

    optimizer.zero_grad()
    y_pred = torch.sigmoid(weights@inputs + bias) # there must be the problem
    loss = criterion(y_pred, labels)
    loss.backward()
    optimizer.step()


from IPython.display import clear_output


for epoch in range(1, 5):

    for batch in train_dataloader: # have to go with batches
      metrics = train_logistic_regression(weights, bias, batch, criterion, optimizer)

Every time, I get this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-161-408b80d71db1> in <module>()
      5 
      6     for batch in train_dataloader:
----> 7       metrics = train_logistic_regression(weights, bias, batch, criterion, optimizer)
      8 
      9 

<ipython-input-160-9c2f95ee56ee> in train_logistic_regression(weights, bias, batch, loss, optimizer)
      6 
      7     optimizer.zero_grad()
----> 8     y_pred = torch.sigmoid(weights@inputs + bias)
      9     # y_pred = model(inputs)
     10     loss = criterion(y_pred, labels)

RuntimeError: size mismatch, m1: [784 x 10], m2: [128 x 784] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41
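The mismatch can be reproduced in isolation with random tensors of the shapes reported in the error (the batch size of 128 is taken from the message). Newer PyTorch versions word the error differently (e.g. "mat1 and mat2 shapes cannot be multiplied"), so this sketch only checks that a RuntimeError is raised.

```python
import torch

batch_size = 128  # taken from the m2 shape in the error message
inputs = torch.randn(batch_size, 784)  # a flattened batch: [128, 784]
weights = torch.randn(784, 10)         # [784, 10]
bias = torch.zeros(10)                 # [10]

# weights @ inputs means [784, 10] x [128, 784]: inner dimensions 10 and 128 differ
try:
    weights @ inputs
    failed = False
except RuntimeError:
    failed = True

# inputs @ weights means [128, 784] x [784, 10] -> [128, 10]; bias broadcasts over the batch
logits = inputs @ weights + bias
```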

I would be grateful if anyone could help.

1 answer:

Answer 0 (score: 0)

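Instead of y_pred = torch.sigmoid(weights@inputs + bias), you should use y_pred = torch.sigmoid(inputs.mm(weights) + bias). After flattening, inputs has shape [batch_size, 784] (here [128, 784]) and weights has shape [784, 10], so the batch must be the left operand: inputs.mm(weights) (equivalently inputs @ weights) gives [128, 10], and bias, of shape [10], broadcasts across the batch. Writing weights@inputs multiplies [784 x 10] by [128 x 784], whose inner dimensions (10 and 128) do not match, which is exactly what the error reports.

Also note that nn.CrossEntropyLoss expects raw, unnormalized scores (logits) because it applies log-softmax internally, so the sigmoid is unnecessary; passing inputs.mm(weights) + bias directly to the loss is the usual approach.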
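A sketch of the full training step with the operands in the batch-first order might look like this. It passes raw logits to nn.CrossEntropyLoss, uses the criterion argument instead of the global, and returns the loss so that metrics receives a value. For a self-contained example it substitutes torch.optim.Adam for the asker's own Adam class and uses a random batch shaped like FashionMNIST ([128, 1, 28, 28] images, integer labels); both are assumptions for illustration.

```python
import math

import torch
from torch import nn


def train_logistic_regression(weights, bias, batch, criterion, optimizer):
    """One optimization step of multinomial logistic regression; returns the loss."""
    inputs, labels = batch
    inputs = inputs.view(inputs.shape[0], -1)  # [B, 1, 28, 28] -> [B, 784]

    optimizer.zero_grad()
    logits = inputs @ weights + bias  # [B, 784] @ [784, 10] -> [B, 10]
    loss = criterion(logits, labels)  # CrossEntropyLoss applies log-softmax itself
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage with a fake FashionMNIST-shaped batch (random data, illustration only)
weights = (torch.randn(784, 10) / math.sqrt(784)).requires_grad_()
bias = torch.zeros(10, requires_grad=True)
optimizer = torch.optim.Adam([weights, bias], lr=1e-3)  # stand-in for the custom Adam
criterion = nn.CrossEntropyLoss()

batch = (torch.randn(128, 1, 28, 28), torch.randint(0, 10, (128,)))
loss_value = train_logistic_regression(weights, bias, batch, criterion, optimizer)
```

At prediction time, logits.argmax(dim=1) gives the predicted class; a softmax is only needed if actual probabilities are required.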