我是PyTorch的新手,我想在我的培训和验证循环中高效地评估F1。
到目前为止,我的方法是在GPU上计算预测,然后将其推入CPU,并将其附加到用于训练和验证的向量上。经过培训和验证后,我将使用sklearn对每个时期进行评估。但是,对它显示的代码进行概要分析表明,推送到cpu是一个瓶颈。
for epoch in range(n_epochs):
model.train()
avg_loss = 0
avg_val_loss = 0
train_pred = np.array([])
val_pred = np.array([])
# Training loop (transpose X_batch to fit pretrained (features, samples) style)
for X_batch, y_batch in train_loader:
scores = model(X_batch)
y_pred = F.softmax(scores, dim=1)
train_pred = np.append(train_pred, self.get_vector(y_pred.detach().cpu().numpy()))
loss = loss_fn(scores, self.get_vector(y_batch))
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() / len(train_loader)
model.eval()
# Validation loop
for X_batch, y_batch in val_loader:
with torch.no_grad():
scores = model(X_batch)
y_pred = F.softmax(scores, dim=1)
val_pred = np.append(val_pred, self.get_vector(y_pred.detach().cpu().numpy()))
loss = loss_fn(scores, self.get_vector(y_batch))
avg_val_loss += loss.item() / len(val_loader)
# Model Checkpoint for best validation f1
val_f1 = self.calculate_metrics(train_targets[val_index], val_pred, f1_only=True)
if val_f1 > best_val_f1:
prev_best_val_f1 = best_val_f1
best_val_f1 = val_f1
torch.save(model.state_dict(), self.PATHS['xlm'])
evaluated_epoch = epoch
# Calc the metrics
self.save_metrics(train_targets[train_index], train_pred, avg_loss, 'train')
self.save_metrics(train_targets[val_index], val_pred, avg_val_loss, 'val')
我敢肯定,有一种更有效的方法 a)存储预测,而不必将其推入每批CPU。 b)直接在GPU上计算指标?
由于我是PyTorch的新手,非常感谢您提供任何提示和反馈:)
答案 0 :(得分:0)
您可以在pytorch中自己计算F分数。 F1分数仅针对单类(真/假)分类定义。您唯一需要做的就是汇总数量:
让我们假设您要在softmax中为索引为0
的类计算F1分数。在每一批中,您可以执行以下操作:
predicted_classes = torch.argmax(y_pred, dim=1) == 0
target_classes = self.get_vector(y_batch)
target_true += torch.sum(target_classes == 0).float()
predicted_true += torch.sum(predicted_classes).float()
correct_true += torch.sum(
predicted_classes == target_classes * predicted_classes == 0).float()
处理完所有批次后:
recall = correct_true / target_true
precision = correct_true / predicted_true
f1_score = 2 * precission * recall / (precision + recall)
别忘了处理精度和召回率均为零且根本无法预测所需类别的情况。