How do I evaluate a StratifiedKFold model?

Time: 2019-05-14 10:15:53

Tags: python numpy tensorflow keras

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.wrappers.scikit_learn import KerasClassifier 
    from sklearn.model_selection import StratifiedKFold 
    from sklearn.model_selection import cross_val_score
    from sklearn.model_selection import cross_val_predict   

    x_train = dataset[0:700,:-1]
    y_train = dataset[0:700,-1]
    x_test = dataset[700:,:-1]
    y_test = dataset[700:,-1]

    def create_model():
        model = Sequential()
        model.add(Dense(12, input_dim=8, activation='relu'))
        model.add(Dense(8, activation='relu'))
        model.add(Dense(1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=64)
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=seed) 

    scores = cross_val_score(model, x_train, y_train, cv=skf)
    predictions = cross_val_predict(model, x_test, y_test, cv=skf)

I want to train on [x_train], [y_train] using StratifiedKFold and then evaluate on [x_test], [y_test]. How can I do that? I tried cross_val_predict, but I don't think it is appropriate here.

2 answers:

Answer 0 (score: 0)

To split the data into training and test sets in a stratified way, you can use:

from sklearn.model_selection import train_test_split
dataset_train, dataset_test = train_test_split(dataset,
                                                stratify=dataset[:,-1], 
                                                test_size=0.2)

#split both datasets into X,y

See:

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

Stratified Train/Test-split in scikit-learn
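
To make the "split both datasets into X,y" step concrete, here is a minimal sketch. The synthetic `dataset` (768 rows, 8 feature columns, binary label in the last column) and the `random_state` values are assumptions made only so the snippet runs on its own; substitute your real array.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for `dataset`: 8 feature columns, binary label last.
rng = np.random.RandomState(0)
features = rng.rand(768, 8)
labels = (rng.rand(768) < 0.35).astype(float)
dataset = np.column_stack([features, labels])

# Stratify on the label column so both parts keep the same class ratio.
dataset_train, dataset_test = train_test_split(
    dataset, stratify=dataset[:, -1], test_size=0.2, random_state=0
)

# Split both datasets into X, y.
x_train, y_train = dataset_train[:, :-1], dataset_train[:, -1]
x_test, y_test = dataset_test[:, :-1], dataset_test[:, -1]

# The positive-class rate should be nearly identical in both parts.
print("train positive rate:", y_train.mean())
print("test positive rate: ", y_test.mean())
```

Comparing the two printed rates is a quick sanity check that stratification did what you expected.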

Answer 1 (score: 0)

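skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=seed)
accuracy = []
# split() yields a (train indices, validation indices) pair for each fold
for train_idx, val_idx in skf.split(x_train, y_train):
    # build a fresh model for every fold so no weights carry over
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # train on this fold's training part, score on its validation part
    model.fit(x_train[train_idx], y_train[train_idx], epochs=100, batch_size=64, verbose=0)
    loss, acc = model.evaluate(x_train[val_idx], y_train[val_idx], verbose=0)
    accuracy.append(acc)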

How about this? It runs, but I don't know whether it is correct.
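
One common way to combine the two ideas is sketched below: use StratifiedKFold only on the training data to estimate how well the model generalizes, then refit on all of the training data and score once on the held-out test set. Calling cross_val_predict on x_test would instead train new models on folds of the test set, which is probably not what the question intends. The synthetic data, the seed value 7, and the LogisticRegression stand-in inside `create_model` are assumptions chosen so the sketch runs quickly; the Keras model from the question can be swapped in.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Hypothetical data: 8 features, binary label driven mostly by feature 0.
rng = np.random.RandomState(0)
x = rng.rand(768, 8)
y = (x[:, 0] + 0.1 * rng.randn(768) > 0.5).astype(int)
x_train, y_train = x[:700], y[:700]
x_test, y_test = x[700:], y[700:]

def create_model():
    # Stand-in classifier (assumption); replace with the Keras model.
    return LogisticRegression(max_iter=1000)

# Step 1: cross-validate on the training data only.
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=7)
fold_scores = []
for train_idx, val_idx in skf.split(x_train, y_train):
    model = create_model()  # fresh model per fold
    model.fit(x_train[train_idx], y_train[train_idx])
    fold_scores.append(model.score(x_train[val_idx], y_train[val_idx]))

# Step 2: refit on the full training set, evaluate once on the test set.
final_model = create_model()
final_model.fit(x_train, y_train)
test_accuracy = final_model.score(x_test, y_test)

print("CV accuracy per fold:", fold_scores)
print("mean CV accuracy:    ", np.mean(fold_scores))
print("test accuracy:       ", test_accuracy)
```

The fold scores tell you how stable the model is across splits; the single test score is the number to report, since the test rows never took part in training or model selection.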
