Why does this PyTorch network always predict the same class while the loss doesn't improve?

Asked: 2019-12-24 18:38:35

Tags: python machine-learning neural-network pytorch cross-entropy

I am trying to train a simple neural network on feature vectors (length 2048) extracted from images with ResNet50. Below are my network architecture and training procedure. I don't understand why the loss does not decrease and why the network keeps predicting the same class. (However, each time the network is re-initialized it is a different class: it always predicts class 3, but after re-initializing it always predicts class 5, and so on.) I have tried changing the learning rate, momentum, etc. Any help would be greatly appreciated. Thanks.

import numpy as np
import torch
from torch import nn, optim
from sklearn import preprocessing

class Feedforward(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(2048, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 10)
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
# encode labels & convert data to tensors
le = preprocessing.LabelEncoder()
labels = le.fit_transform(labels)

targets = []
for label in labels:
    targets.append(np.array([label]))
tensor_x = torch.stack([torch.Tensor(i) for i in features]) # transform to torch tensors
tensor_y = torch.stack([torch.LongTensor(i) for i in targets])
print(tensor_x.shape)
print(tensor_y.shape)

#prepare data for training
train_data = torch.utils.data.TensorDataset(tensor_x, tensor_y)
trainloader = torch.utils.data.DataLoader(train_data, shuffle=True)
hidden_size = 256
net = Feedforward(hidden_size)
loss_funct = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
for epoch in range(5):  # loop over the dataset multiple times
    running_loss = 0.0
    total = 0
    correct = 0
    for i, data in enumerate(trainloader, 0):
        inputs, trainlabels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)

        tl = trainlabels.view(1)
        ot = outputs.view(1,10)

        loss = loss_funct(ot, tl) #outputs[0,:], trainlabels[0]
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(ot,1)
        #print(predicted)

        total += trainlabels[0].size(0)
        correct += (predicted == trainlabels[0]).sum().item()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
    print(100*correct/total)
print('Finished Training')
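For comparison, here is a minimal, self-contained sketch of the same architecture trained on random stand-in data (the feature tensors and sample count are assumptions, not the original dataset). It uses a batched DataLoader and passes the 2-D outputs straight to CrossEntropyLoss, so none of the per-sample view() reshaping above is needed:

```python
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Synthetic stand-ins for the ResNet50 features (assumed 1000 samples, 10 classes)
features = torch.randn(1000, 2048)
labels = torch.randint(0, 10, (1000,))      # class indices, shape (N,)

net = nn.Sequential(nn.Linear(2048, 256), nn.ReLU(), nn.Linear(256, 10))
loss_funct = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

loader = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=True)

for epoch in range(5):
    for inputs, targets in loader:
        optimizer.zero_grad()
        outputs = net(inputs)               # shape (batch, 10)
        loss = loss_funct(outputs, targets) # targets have shape (batch,)
        loss.backward()
        optimizer.step()
```

Because the labels here are 1-D index tensors rather than stacked `[label]` arrays, no `.view(1)` / `.view(1, 10)` calls are required, and the loop works for any batch size.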

0 Answers:

There are no answers.