Torch:训练后如何检查权重?

时间:2021-02-26 22:17:44

标签: python pytorch

在查看训练期间权重如何变化时,我想知道我做错了什么。

我的损失大幅下降,但似乎初始化权重与训练权重相同。我是不是找错地方了?如果您有任何见解,我将不胜感激!

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# setup GPU/CPU processing
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# initialize model
class mlp1(torch.nn.Module):
    """Single-hidden-layer MLP: linear -> sigmoid -> linear.

    forward() returns a pair ``(logits, probas)``: the raw logits (suitable
    for ``F.cross_entropy``, which applies log-softmax itself) and the
    softmax probabilities over the class dimension.
    """

    def __init__(self, num_features, num_hidden, num_classes):
        super(mlp1, self).__init__()
        self.num_classes = num_classes
        self.input_layer = torch.nn.Linear(num_features, num_hidden)
        self.out_layer = torch.nn.Linear(num_hidden, num_classes)

    def forward(self, x):
        # hidden activation, then the output projection
        hidden = torch.sigmoid(self.input_layer(x))
        logits = self.out_layer(hidden)
        return logits, torch.softmax(logits, dim=1)

# instantiate model
# instantiate model
model = mlp1(num_features=28*28, num_hidden=100, num_classes=10).to(device)

# check initial weights
# BUG FIX: state_dict() returns references to the *live* parameter tensors,
# so a plain slice is a view that keeps tracking the weights as they train.
# .clone() snapshots the current values into an independent tensor.
weight_check_pre = model.state_dict()['input_layer.weight'][0][0:25].clone()

# optim
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


# download data (MNIST train split, transformed to tensors)
train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# data loader
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=100,
                              shuffle=True)

# train
NUM_EPOCHS = 1
for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_dataloader):
        # flatten 28x28 images to (batch, 784) and move the batch to device
        features = features.view(-1, 28*28).to(device)
        targets = targets.to(device)
        # forward
        logits, probas = model(features)
        # cross_entropy expects raw logits (applies log-softmax internally)
        loss = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        # now update weights
        optimizer.step()
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'
                   %(epoch+1, NUM_EPOCHS, batch_idx,
                     len(train_dataloader), loss))

# check post training — snapshot again so the comparison is between two
# independent copies rather than two views of the same storage
weight_check_post = model.state_dict()['input_layer.weight'][0][0:25].clone()

# compare — now correctly differs after training
weight_check_pre == weight_check_post

1 个答案:

答案 0 :(得分:1)

这是因为 state_dict() 返回的是对模型当前参数张量的引用:两个变量保存的切片其实是同一块底层参数存储的视图(view),而不是独立的副本,所以无论训练前后读取,它们始终相等。

您可以这样做以获得 state_dict 的实际副本。

import copy

# check initial weights
# deepcopy materializes an independent snapshot of the slice instead of
# keeping a view into the live parameter tensor, so subsequent training
# updates cannot alter it.
weight_check_pre = copy.deepcopy(model.state_dict()['input_layer.weight'][0][0:25])
...
# check post training
weight_check_post = copy.deepcopy(model.state_dict()['input_layer.weight'][0][0:25])