如何查看两个类别的准确度,而不是在 bert 情感分析中单独显示每个类别的准确度

时间:2021-03-29 22:47:48

标签: python sentiment-analysis bert-language-model

我刚开始使用 bert,老实说我觉得有点迷茫。我一直在尝试各种在线 github/kaggle 代码,看看它在我的数据集上是如何工作的。无论如何,我获得了 bert 的准确性以及 f1 分数,但是每个标签的准确性是单独显示的,而不是两者的加权平均值。此外,精度显示为分数。我如何设法获得两个标签的加权平均值?我附上了我的结果

识别f1分数和准确率的代码如下:

def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average = 'weighted')

def accuracy_per_class(preds, labels):
    label_dict_inverse = {v: k for k, v in label_dict.items()}
    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {label_dict_inverse[label]}')
        print(f'Accuracy:{len(y_preds[y_preds==label])}/{len(y_true)}\n')

。 . . . . . .

, predictions, true_val = evaluate(dataloader_val)  #why _ ? reason behind this is evaluate function return 3 values and i don't require the 1st value i.e., loss_val_avg
accuracy_per_class(predictions, true_val)

结果如下所示:

   100%|██████████| 5/5 [00:00<00:00,  7.39it/s]
    Class: 0
    Accuracy:59/81
    
    Class: 1
    Accuracy:54/69
    
    
    F1: 0.7537174638487207

但是我希望我的准确度显示如下:0.65272 而不是单独类别的一部分

1 个答案:

答案 0 :(得分:0)

你可以使用这个循环:

for label in np.unique(labels_flat):
    y_preds = preds_flat[labels_flat==label]
    y_true = labels_flat[labels_flat==label]
    print(f'Class: {label_dict_inverse[label]}')
    print(f'Accuracy:{len(y_preds[y_preds==label])}/{len(y_true)}\n')

并将其扩展为:

total_correct = 0
total = 0

for label in np.unique(labels_flat):
    y_preds = preds_flat[labels_flat==label]
    y_true = labels_flat[labels_flat==label]
    a = len(y_preds[y_preds==label])
    b = len(y_true)
    total_correct += a
    total += b
    print(f'Class: {label_dict_inverse[label]}')
    print(f'Accuracy:{a}/{b}\n')

print(f`Overall accuracy: {total_correct/total:.2f}\n')

(这只是一点点黑客,基于知道你有类的准确性。如果你确定你从不关心它们,肯定会有更有效的方法,但因为这段代码永远不会一个瓶颈,特别是考虑到上下文是一个 BERT 模型,担心它会过早优化。)