I am trying to compute some statistics on an intermediate representation of the input data. Using .register_forward_hook
is the natural choice. However, computing this statistic takes quite a long time, so I can only afford to do it once (in my case, only on the first batch of validation data).
The skeleton of the hook function looks like:

stat_list = []
def hook(self, input, output):
    if model.training == False:
        # compute stat
        stat_list.append(stat)
However, I am not sure how to do this correctly. I have tried two things:

- accessing the loop index i from the validation loop for i, (X_val, y_val) in enumerate(val_dataloader):
stat_list = []
def hook(self, input, output):
    if (model.training == False) and (i == 0):
        # do something
- using a flag that is cleared inside the hook after the first call:

flag = True
stat_list = []
def hook(self, input, output):
    if (model.training == False) and (flag == True):
        flag = False
        # do something
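To make the flag idea concrete, here is a self-contained toy version of the pattern I am after (the model and the statistic are placeholders of my own; note that a plain flag = False inside the hook would only rebind a local name, so a one-element list is used as a mutable flag):

```python
import torch
from torch import nn

# Toy stand-ins for the real model; the list wrapper lets the
# hook mutate the flag without a `global` declaration.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
stat_list = []
first_val_batch = [True]  # mutable flag

def hook(module, input, output):
    if (not model.training) and first_val_batch[0]:
        first_val_batch[0] = False
        # cheap stand-in statistic: mean activation of this layer
        stat_list.append(output.detach().mean().item())

model[0].register_forward_hook(hook)

model.eval()
for _ in range(3):  # three "validation batches"
    model(torch.randn(8, 4))

print(len(stat_list))  # the hook body ran only once
```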
but neither of them works.
Could someone help me out?
For the sake of discussion, I provide the following runnable code.
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.models import AlexNet
from torchvision.datasets import FakeData
from sklearn.metrics import accuracy_score
NUM_CLASSES = 2
LR = 1e-4
BATCH_SIZE = 32
MAX_EPOCH = 2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
transform = T.ToTensor()
train_dataset = FakeData(size=800, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
val_dataset = FakeData(size=200, image_size=(3, 224, 224), num_classes=NUM_CLASSES, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
model = AlexNet(num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
def hook(self, input, output):
    if model.training == False:
        print("validation...")
model.features.register_forward_hook(hook)
for epoch in range(MAX_EPOCH):
    model.train()
    print("epoch %d / %d" % ((epoch + 1), MAX_EPOCH))
    for i, (X_train, y_train) in enumerate(train_dataloader):
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.int64).to(device)
        optimizer.zero_grad()
        score = model(X_train)
        loss = criterion(input=score, target=y_train)
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("\tloss: %.5f" % loss.item())
    model.eval()
    y_pred_list = list()
    y_val_list = list()
    for X_val, y_val in val_dataloader:
        X_val = X_val.type(torch.float32).to(device)
        score = model(X_val)
        y_pred_list.extend(torch.topk(score, k=1, dim=1)[1].detach().squeeze().cpu().numpy())
        y_val_list.extend(y_val)
    print("\tvalidation accuracy: %.5f" % accuracy_score(y_true=y_val_list, y_pred=y_pred_list))
Below is the training and validation log. Since I only check if model.training == False
in the hook function, this output is as expected.
epoch 1 / 2
loss: 0.69933
loss: 0.69549
validation...
validation...
validation...
validation...
validation...
validation...
validation...
validation accuracy: 0.46000
epoch 2 / 2
loss: 0.70441
loss: 0.69369
validation...
validation...
validation...
validation...
validation...
validation...
validation...
validation accuracy: 0.54000
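For completeness, another approach I considered (not shown in my code above) is to deregister the hook via the handle returned by register_forward_hook once the first validation batch has gone through. A minimal toy sketch, with a placeholder model and fake batches standing in for val_dataloader:

```python
import torch
from torch import nn

# Placeholder model and validation batches (hypothetical names).
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
val_batches = [torch.randn(8, 4) for _ in range(3)]
stat_list = []

def hook(module, input, output):
    if not model.training:
        # cheap stand-in statistic: mean activation of this layer
        stat_list.append(output.detach().mean().item())

handle = model[0].register_forward_hook(hook)

model.eval()
with torch.no_grad():
    for i, X_val in enumerate(val_batches):
        model(X_val)
        if i == 0:
            handle.remove()  # deregister: later batches skip the hook

print(len(stat_list))  # the hook fired only on the first batch
```

The drawback is that the hook is gone for good, so it would not fire again in the next epoch's validation, which happens to match my use case.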