在tfhub重新训练脚本中计算F1得分,精度,召回率

时间:2018-11-01 13:50:07

标签: python tensorflow

我正在使用tensorflow hub进行图像再训练分类任务。张量流脚本retrain.py默认情况下计算cross_entropy和准确性。

train_accuracy, cross_entropy_value = sess.run([evaluation_step, cross_entropy],feed_dict={bottleneck_input: train_bottlenecks, ground_truth_input: train_ground_truth})

我想获得F1得分,准确性,召回率和混乱矩阵。如何使用此脚本获取这些值?

1 个答案:

答案 0 :(得分:5)

下面,我提供了一种使用 scikit-learn 软件包计算所需指标的方法。

您可以使用precision_recall_fscore_support方法计算F1得分,准确性和召回率,并使用confusion_matrix方法计算混淆矩阵:

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

这两种方法都采用两个类似一维数组的对象,分别存储地面真相和预测标签。

在提供的代码中,用于训练数据的真实标签存储在10541060行中的train_ground_truth变量中,而validation_ground_truth存储地面变量,验证数据的真实标签,并在第1087行中定义。

用于计算预测类标签的张量由add_evaluation_step函数定义并返回。您可以修改第1034行以捕获该张量对象:

evaluation_step, prediction = add_evaluation_step(final_tensor, ground_truth_input)
# now prediction stores the tensor object that 
# calculates predicted class labels

现在,您可以更新第1076行,以便在调用prediction时评估sess.run()

train_accuracy, cross_entropy_value, train_predictions = sess.run(
    [evaluation_step, cross_entropy, prediction],
    feed_dict={bottleneck_input: train_bottlenecks,
               ground_truth_input: train_ground_truth})

# train_predictions now stores class labels predicted by model

# calculate precision, recall and F1 score
(train_precision,
 train_recall,
 train_f1_score, _) = precision_recall_fscore_support(y_true=train_ground_truth,
                                                      y_pred=train_predictions,
                                                      average='micro')
# calculate confusion matrix
train_confusion_matrix = confusion_matrix(y_true=train_ground_truth,
                                          y_pred=train_predictions)

类似地,您可以通过修改第1095行来计算验证子集的指标:

validation_summary, validation_accuracy, validation_predictions = sess.run(
    [merged, evaluation_step, prediction],
    feed_dict={bottleneck_input: validation_bottlenecks,
               ground_truth_input: validation_ground_truth})

# validation_predictions now stores class labels predicted by model

# calculate precision, recall and F1 score
(validation_precision,
 validation_recall,
 validation_f1_score, _) = precision_recall_fscore_support(y_true=validation_ground_truth,
                                                           y_pred=validation_predictions,
                                                           average='micro')
# calculate confusion matrix
validation_confusion_matrix = confusion_matrix(y_true=validation_ground_truth,
                                               y_pred=validation_predictions)

最后,代码调用run_final_eval以根据测试数据评估经过训练的模型。在此函数中,已经定义了predictiontest_ground_truth,因此您只需要包括代码即可计算所需的指标:

test_accuracy, predictions = eval_session.run(
    [evaluation_step, prediction],
    feed_dict={
        bottleneck_input: test_bottlenecks,
        ground_truth_input: test_ground_truth
    })

# calculate precision, recall and F1 score
(test_precision,
 test_recall,
 test_f1_score, _) = precision_recall_fscore_support(y_true=test_ground_truth,
                                                     y_pred=predictions,
                                                     average='micro')
# calculate confusion matrix
test_confusion_matrix = confusion_matrix(y_true=test_ground_truth,
                                         y_pred=predictions)

请注意,提供的代码通过设置average='micro'来计算 global F1分数。 User Guide中描述了scikit-learn软件包支持的不同平均方法。