Call a forward hook function only once in PyTorch

Time: 2019-11-21 21:14:34

Tags: pytorch

Question

I am trying to compute some statistics on an intermediate representation of the input data. Using .register_forward_hook feels natural for this. However, computing these statistics has a fairly long runtime, so I can only afford to do it once (in my case, only on the first batch of validation data).

The skeleton of the hook function looks like:

stat_list = []
def hook(self, input, output):
    if model.training == False:
        # compute stat
        stat_list.append(stat)

However, I am not sure how to do this correctly. I have tried two things:

  • accessing the loop index i of the validation loop for i, (X_val, y_val) in enumerate(val_dataloader):

stat_list = []
def hook(self, input, output):
    if (model.training == False) and (i == 0):
        # do something
  • setting a global flag and modifying it inside the hook:

flag = True
stat_list = []
def hook(self, input, output):
    if (model.training == False) and (flag == True):
        flag = False
        # do something

But neither of them works.
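For reference, the global-flag approach can be made to work if the hook is a regular def and the flag is declared global inside it (otherwise flag = False just creates a new local variable). A minimal self-contained sketch; the small nn.Sequential model and the mean statistic are placeholders for the real model and computation:

```python
import torch
from torch import nn

# Placeholder model and statistic; substitute the real ones.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

stat_list = []
ran_once = False  # module-level flag

def hook(module, input, output):
    global ran_once
    # Run the expensive computation only once, and only in eval mode.
    if not module.training and not ran_once:
        ran_once = True
        stat_list.append(output.mean().item())

model.register_forward_hook(hook)

model.eval()
with torch.no_grad():
    for _ in range(3):  # stand-in for three validation batches
        model(torch.randn(1, 3, 8, 8))

print(len(stat_list))  # 1 — the body ran for the first batch only
```

The hook itself still fires on every forward pass; the flag only gates the expensive body. If the statistic should be recomputed each epoch, reset ran_once = False before each validation loop.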

Could you help me with this?
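Another option, instead of a flag, is the RemovableHandle that register_forward_hook returns: let the hook run unconditionally and deregister it from the driving loop right after the first validation batch. A hedged sketch with the same placeholder model and statistic:

```python
import torch
from torch import nn

# Placeholder model and statistic; substitute the real ones.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
stat_list = []

def hook(module, input, output):
    stat_list.append(output.mean().item())  # the expensive statistic

# register_forward_hook returns a RemovableHandle.
handle = model.register_forward_hook(hook)

model.eval()
with torch.no_grad():
    for i in range(3):  # stand-in for enumerate(val_dataloader)
        model(torch.randn(1, 3, 8, 8))
        if i == 0:
            handle.remove()  # hook is gone for all later batches

print(len(stat_list))  # 1
```

In the full training script the hook would also fire during training batches, so either keep the model.training check inside the hook or register the handle only just before the validation loop.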

For the sake of discussion, I have included runnable code below.

import numpy as np
import torch

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.models import AlexNet
from torchvision.datasets import FakeData

from sklearn.metrics import accuracy_score

NUM_CLASSES = 2

LR = 1e-4
BATCH_SIZE = 32
MAX_EPOCH = 2

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transform = T.ToTensor()
train_dataset = FakeData(size=800, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
val_dataset = FakeData(size=200, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = AlexNet(num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

def hook(self, input, output):
    if model.training == False:
        print("validation...")

model.features.register_forward_hook(hook)

for epoch in range(MAX_EPOCH):
    model.train()
    print("epoch %d / %d" % ((epoch + 1), MAX_EPOCH))
    for i, (X_train, y_train) in enumerate(train_dataloader):
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.int64).to(device)

        optimizer.zero_grad()

        score = model(X_train)
        loss = criterion(input=score, target=y_train)
        loss.backward()

        optimizer.step()

        if (i + 1) % 10 == 0:
            print("\tloss: %.5f" % loss.item())

    model.eval()
    y_pred_list = list()
    y_val_list = list()
    for X_val, y_val in val_dataloader:
        X_val = X_val.type(torch.float32).to(device)

        score = model(X_val)
        y_pred_list.extend(torch.topk(score, k=1, dim=1)[1].detach().squeeze().cpu().numpy())
        y_val_list.extend(y_val)

    print("\tvalidation accuracy: %.5f" % accuracy_score(y_true=y_val_list, y_pred=y_pred_list))

Below is the training and validation log. Since I only check if model.training == False in the hook function, this output is expected.

epoch 1 / 2
    loss: 0.69933
    loss: 0.69549
validation...
validation...
validation...
validation...
validation...
validation...
validation...
    validation accuracy: 0.46000
epoch 2 / 2
    loss: 0.70441
    loss: 0.69369
validation...
validation...
validation...
validation...
validation...
validation...
validation...
    validation accuracy: 0.54000

0 Answers:

There are no answers yet.