精度为0.995的CNN在实施中效果不佳

时间:2019-06-27 12:58:18

标签: python github jupyter-notebook conv-neural-network

为了简短起见,我针对每个类别中的数据均相同的二元分类训练了模型,以免出现类别失衡。在带有相应标签的10000张图像上训练模型,并在带有相应标签的6000张图像上验证模型。

结果是一个具有0.995精度的模型,这意味着该模型的实现将对时间的正确类进行分类0.995。 (模型没有一直选择A类,并且由于没有类不平衡,所以每次都正确选择0.995)

但是,事实并非如此。而且,数据已经重新整理,因此该模型也不会在前5000张图像中猜测A类,而在其余图像中都猜测B类,从而获得0.995的准确性。

我注意到的完整代码,问题和内容在我的github上:

enter image description here

可以免费下载和使用该模型,以查看飞扬的小鸟机器人的结果。


编辑1:总图像中有8 000张是原始图像,其余8 000张图像如下所述

以下代码段显示了对原始图像的增强

# select opp_line which stage =='7 - Deliver & Validate'  and oplin_status =='Pending'
DF_BR8 = df.filter(df.stage.contains("7 - Deliver")).select('opp_id__reference', 'oplin_status', 'stage', 'std_amount', 'std_line_amount')

DF_BR8_1 = DF_BR8.groupby('opp_id__reference', 'std_amount', 'oplin_status').agg({'std_line_amount': 'sum'}).withColumnRenamed('sum(std_line_amount)','sum_column')

DF_res = DF_BR8_1.filter(DF_BR8_1.oplin_status.contains("Pending"))
DF_res1 =DF_res.filter(DF_res.sum_column <= 0.3*DF_BR8_1.std_amount)

编辑2:以下代码用于生成原始数据集(可在github上找到)

datagen = ImageDataGenerator(featurewise_center=True, samplewise_center=True, 
                             featurewise_std_normalization=True, samplewise_std_normalization=True, 
                             zca_whitening=True, zca_epsilon=1e-06)

1 个答案:

答案 0 :(得分:0)

共有三件事,分别是准确性,准确性和召回率。准确性通常不是分类器的首选性能指标,尤其是在处理偏斜的数据集时(即某些类别的频率比其他类别的频率高得多)。

精度:

    TP
_________
TP  +  FP

回忆:

    TP
_________
TP  +  FN

在哪里, TP =真阳性

FP =假阴性

评估分类器性能的一种更好的方法是查看混淆矩阵。通常的想法是计算A类实例被分类为B类的次数。要计算混淆矩阵,首先需要具有一组预测,以便可以将它们与实际目标进行比较。

An illustrated confusion matrix for digit 5 classifier.

您可以使用scikitkearn库中的cross_val_predict()函数来计算预测目标。

from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

现在,您可以使用confusion_matrix()函数获取混淆矩阵了。 只需将目标类别(y_train_5)和预测类别(y_train_pred)传递给它即可:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)

Scikit-Learn提供了多种功能来计算分类器指标,包括精度和召回率:

from sklearn.metrics import precision_score, recall_score

precision_score(y_train_5, y_train_pred)
recall_score(y_train_5, y_train_pred)

希望这会有所帮助:)