Calculating a 95% CI for AUC with cross-validation (Python, sklearn)

Time: 2020-05-18 07:32:47

Tags: python logistic-regression cross-validation confidence-interval auc

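I am looking for the correct way to calculate a 95% CI for the AUC from my 5-fold CV.

n = 81 in my training dataset.

So with 5-fold CV, each test fold contains on average about n = 16 samples.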
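Here is a quick check of that arithmetic. This is a sketch that assumes a plain KFold-style split. `numpy.array_split` sizes the chunks the same way scikit-learn's `KFold` does: the first 81 % 5 = 1 fold gets one extra sample. `StratifiedKFold`, which `cv=5` uses for a classifier, also balances the classes, so its exact fold sizes may differ slightly.

```python
import numpy as np

n_samples = 81
n_splits = 5

# Split the sample indices into 5 nearly equal, contiguous test folds.
test_folds = np.array_split(np.arange(n_samples), n_splits)
fold_sizes = [len(fold) for fold in test_folds]

print(fold_sizes)            # [17, 16, 16, 16, 16]
print(n_samples / n_splits)  # 16.2
```

With only about 16 samples per test fold, each per-fold AUC is computed from very few positive/negative pairs, so the fold scores are noisy.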

My Python code is below:

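import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_score

folds = 5
seed = 42

# Grid search over regularization strength, intercept and penalty type
fit_intercept = [True, False]
C = np.arange(1, 41, 1)  # a flat array of candidate values, not a list wrapping one array
penalty = ['l1', 'l2']
params = dict(C=C, fit_intercept=fit_intercept, penalty=penalty)

# liblinear supports both 'l1' and 'l2'; the default lbfgs solver does not support 'l1'
logreg = LogisticRegression(solver='liblinear', random_state=seed)

# iid was deprecated in sklearn 0.22 and removed in 0.24 (and iid='False' is a truthy string)
logreg_grid = GridSearchCV(logreg, param_grid=params, cv=folds, scoring='roc_auc')

# fit the grid with data
logreg_grid.fit(X_train, y_train)

# take the best estimator (refit on the whole training set)
logreg = logreg_grid.best_estimator_

# AUC in 5-fold stratified CV (an integer cv with a classifier uses StratifiedKFold)
logreg_scores = cross_val_score(logreg, X_train, y_train, cv=folds, scoring='roc_auc')
print('LogReg:', logreg_scores.mean())

# Reported output:
# LogReg scores: [0.95714286, 0.85, 0.98333333, 0.85, 0.56666667]
# Mean: 0.8414285714285714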

#AUC from LogReg = 0.8414

#Three ways I have tried to calculate the 95 % CI:



### First try ###
import statsmodels.stats.api as sms

# t-based CI for the mean of the 5 fold AUCs (alpha = 0.05, two-sided)
conf = sms.DescrStatsW(logreg_scores).tconfint_mean(alpha=0.05)
print(conf)

# Out: lower 0.636, upper 1.047

### Second try ###
import numpy as np
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    # mean +/- t * standard error, with n - 1 degrees of freedom
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, m - h, m + h


mean_confidence_interval(logreg_scores, confidence=0.95)

# Out: mid 0.84, lower 0.64, upper 1.05

### Third try ###
# interval = t * np.sqrt((AUC * (1 - AUC)) / n)
# n = 16: with n = 81 and 5 folds, each validation fold holds about 16 samples
# t = 2.120 (source: https://www.sjsu.edu/faculty/gerstman/StatPrimer/t-table.pdf)
import numpy as np

auc = 0.8414285714285714  # mean cross-validated AUC
n = 16
t = 2.120
interval = t * np.sqrt((auc * (1 - auc)) / n)

# print everything on the same 0-1 scale
print('Lower:', auc - interval)
print('Mid:', auc)
print('Upper:', auc + interval)
print('Interval:', interval)

# Output (rounded): Lower: 0.648, Mid: 0.841, Upper: 1.035, Interval: 0.194

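My question: all three results look similar. But I must be doing something wrong, because I don't see how the AUC (the upper bound of its CI) can be greater than 1.0.

Thank you for your time; I look forward to your answers.

Cheers, Misha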

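2 answers:

Answer 0 (score: 1)

I'm not sure this solves your problem, but I suspect it's because you are applying a t-test to an extremely small sample size (n = 5). A large variance is to be expected, which is why in your case mean + SD > 1. Note that all three of your methods are based on the t-test.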
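The arithmetic behind this answer can be sketched with the five fold scores from the question. This sketch uses only numpy. The two-sided 95% critical value for df = 4 (≈ 2.776) is hardcoded from a t table; `scipy.stats.t.ppf(0.975, 4)` gives the same value.

```python
import numpy as np

# Per-fold AUCs reported in the question (5-fold CV).
scores = np.array([0.95714286, 0.85, 0.98333333, 0.85, 0.56666667])

n = len(scores)                       # only 5 observations
mean = scores.mean()
se = scores.std(ddof=1) / np.sqrt(n)  # standard error of the mean
t_crit = 2.776                        # two-sided 95% t critical value, df = n - 1 = 4

# A symmetric interval mean +/- t * SE; nothing constrains it to [0, 1].
lower, upper = mean - t_crit * se, mean + t_crit * se
print(round(mean, 3), round(lower, 3), round(upper, 3))  # 0.841 0.636 1.047
```

A large critical value (2.776 rather than 1.96) and one low fold (0.567) are enough to push the upper bound past 1.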

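To get a sufficient number of comparisons, you might want to try 1) repeated CV with different random splits, or 2) bootstrapping. There is some useful discussion of this on Cross Validated (CV).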
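For option 2), here is a minimal percentile-bootstrap sketch. It assumes 0/1 labels and one out-of-fold predicted probability per training sample, for example from `cross_val_predict(..., method='predict_proba')` with a plain `StratifiedKFold`. To stay self-contained, it computes the AUC with the pairwise (Mann-Whitney) formula, which for 0/1 labels gives the same value as `sklearn.metrics.roc_auc_score`. The helper names, `n_boot`, and the toy data are hypothetical.

```python
import numpy as np


def auc_mann_whitney(y_true, y_score):
    """AUC = P(random positive scores higher than random negative); ties count 1/2."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive-negative pair
    return (diff > 0).mean() + 0.5 * (diff == 0).mean()


def bootstrap_auc_ci(y_true, y_score, n_boot=2000, confidence=0.95, seed=42):
    """Percentile bootstrap CI for the AUC (hypothetical helper)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    if len(np.unique(y_true)) < 2:
        raise ValueError("y_true must contain both classes")
    rng = np.random.default_rng(seed)
    n = len(y_true)
    aucs = []
    while len(aucs) < n_boot:
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUC is undefined if a resample has only one class
        aucs.append(auc_mann_whitney(y_true[idx], y_score[idx]))
    alpha = 1 - confidence
    lower, upper = np.percentile(aucs, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return auc_mann_whitney(y_true, y_score), lower, upper


# Hypothetical toy data standing in for y_train and out-of-fold probabilities.
rng = np.random.default_rng(0)
y_demo = rng.integers(0, 2, size=81)
score_demo = 0.5 * y_demo + rng.normal(0.0, 0.5, size=81)

auc, lower, upper = bootstrap_auc_ci(y_demo, score_demo)
print(f"AUC {auc:.3f}, 95% CI [{lower:.3f}, {upper:.3f}]")
```

Unlike mean ± t·SE, percentile bounds cannot leave [0, 1], because every bootstrap AUC lies in that range.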

Answer 1 (score: 1)

This was very helpful, Tianlin He! Thank you.

I implemented it like this:


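import numpy as np
import scipy.stats
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

# 5-fold stratified CV repeated 100 times with different shuffles -> 500 fold AUCs
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=100, random_state=seed)

logreg_scores = cross_val_score(logreg, X_train, y_train, cv=cv, scoring='roc_auc')
print('LogReg:', logreg_scores.mean())


def mean_confidence_interval(data, confidence=0.95):
    # mean +/- t * standard error, with n - 1 degrees of freedom
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, m - h, m + h

mean_confidence_interval(logreg_scores, confidence=0.95)

The output looks fine, because I now have 500 AUCs (5 folds × 100 repeats):
>>> (0.8014285714285716, 0.7921705464185262, 0.810686596438617)

But how can I do the same with predicted probabilities? When I try to get them with cross_val_predict and this repeated cv object, it raises the error: "cross_val_predict only works for partitions".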
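`cross_val_predict` needs every sample to appear in exactly one test fold. `RepeatedStratifiedKFold` violates this, because each sample is tested once per repeat. One possible workaround, sketched below, is to loop over the repeats yourself:

- build a shuffled `StratifiedKFold` with a different `random_state` per repeat,
- call `cross_val_predict(..., method="predict_proba")` on it,
- compute one pooled AUC per repeat.

The helper name `repeated_oof_auc` is hypothetical, and the synthetic `make_classification` data stands in for `X_train` / `y_train`. The spread of these per-repeat AUCs reflects only split-to-split variability, not the sampling uncertainty of the data, so on its own it is not a full 95% CI.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict


def repeated_oof_auc(model, X, y, n_splits=5, n_repeats=100, seed=42):
    """Return one pooled out-of-fold AUC per CV repeat (hypothetical helper)."""
    aucs = []
    for r in range(n_repeats):
        # Each repeat on its own is a true partition, so cross_val_predict accepts it.
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed + r)
        # Column 1 holds the predicted probability of the positive class (label 1).
        proba = cross_val_predict(model, X, y, cv=cv, method="predict_proba")[:, 1]
        aucs.append(roc_auc_score(y, proba))
    return np.array(aucs)


# Synthetic stand-in for X_train / y_train (81 samples, as in the question).
X_demo, y_demo = make_classification(n_samples=81, n_features=5, random_state=0)
model = LogisticRegression(solver="liblinear", random_state=42)

aucs = repeated_oof_auc(model, X_demo, y_demo, n_repeats=20)
print("mean AUC:", aucs.mean())
print("2.5th/97.5th percentiles across repeats:", np.percentile(aucs, [2.5, 97.5]))
```

With the real data you would pass the tuned `logreg`, `X_train` and `y_train`. Within one repeat, the out-of-fold probabilities can also be fed to a bootstrap to capture sampling uncertainty.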