How to do N-fold cross-validation with KNN in Python sklearn?

Date: 2016-11-26 04:46:50

Tags: python machine-learning scikit-learn knn

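I'm new to machine learning and I'm trying to run the KNN algorithm on the KDD Cup 1999 dataset. I managed to build the classifier and predict on the dataset, and the accuracy came out at about 92%.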

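But I realized this accuracy figure may not be reliable, because the training and test sets are fixed, and the result could be different with a different split of the data.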

So how do I do N-fold cross-validation?

Here is my code:

import pandas
from time import time
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in 0.20
from sklearn.metrics import accuracy_score
#TRAINING
col_names = ["duration","protocol_type","service","flag","src_bytes",
    "dst_bytes","land","wrong_fragment","urgent","hot","num_failed_logins",
    "logged_in","num_compromised","root_shell","su_attempted","num_root",
    "num_file_creations","num_shells","num_access_files","num_outbound_cmds",
    "is_host_login","is_guest_login","count","srv_count","serror_rate",
    "srv_serror_rate","rerror_rate","srv_rerror_rate","same_srv_rate",
    "diff_srv_rate","srv_diff_host_rate","dst_host_count","dst_host_srv_count",
    "dst_host_same_srv_rate","dst_host_diff_srv_rate","dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate","dst_host_serror_rate","dst_host_srv_serror_rate",
    "dst_host_rerror_rate","dst_host_srv_rerror_rate","label"]
kdd_data_10percent = pandas.read_csv("data/kdd_10pc", header=None, names = col_names)

num_features = [
    "duration","src_bytes",
    "dst_bytes","land","wrong_fragment","urgent","hot","num_failed_logins",
    "logged_in","num_compromised","root_shell","su_attempted","num_root",
    "num_file_creations","num_shells","num_access_files","num_outbound_cmds",
    "is_host_login","is_guest_login","count","srv_count","serror_rate",
    "srv_serror_rate","rerror_rate","srv_rerror_rate","same_srv_rate",
    "diff_srv_rate","srv_diff_host_rate","dst_host_count","dst_host_srv_count",
    "dst_host_same_srv_rate","dst_host_diff_srv_rate","dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate","dst_host_serror_rate","dst_host_srv_serror_rate",
    "dst_host_rerror_rate","dst_host_srv_rerror_rate"
]
features = kdd_data_10percent[num_features].astype(float)


#classifying all labels not "normal" as attack
labels = kdd_data_10percent['label'].copy()
labels[labels!='normal.'] = 'attack.'
print(labels.value_counts())

#TODO: Normalising of data
#TODO: Principal Component Analysis - Data reduction

clf = KNeighborsClassifier(n_neighbors = 5, algorithm = 'ball_tree', leaf_size=500)
t0 = time()
clf.fit(features,labels)
tt = time()-t0
print("Classifier trained in {} seconds".format(round(tt, 3)))

#TESTING
kdd_data_test = pandas.read_csv("data/corrected", header=None, names = col_names)
# .loc avoids chained assignment, which may silently leave the DataFrame unchanged
kdd_data_test.loc[kdd_data_test['label'] != 'normal.', 'label'] = 'attack.'
kdd_data_test[num_features] = kdd_data_test[num_features].astype(float)
features_train, features_test, labels_train, labels_test = train_test_split(
    kdd_data_test[num_features], 
    kdd_data_test['label'], 
    test_size=0.1, 
    random_state=42)
t0 = time()
pred = clf.predict(features_test)
tt = time() - t0
print("Predicted in {} seconds".format(round(tt, 3)))

acc = accuracy_score(labels_test, pred)  # argument order is (y_true, y_pred)
print("Accuracy is {}.".format(round(acc, 4)))  # this is accuracy, not R squared

Any guidance is appreciated! Thank you very much!

1 Answer

Answer 0 (score: 2)

K-fold cross validation

from sklearn.model_selection import KFold

X = ["a", "b", "c", "d"]
kf = KFold(n_splits=2)
for train, test in kf.split(X):
    print("%s %s" % (train, test))

Output (train indices, then test indices, into X):
[2 3] [0 1]
[0 1] [2 3]
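KFold itself only produces index arrays. Applying it to the question's KNeighborsClassifier means fitting on the training indices and scoring on the test indices of each fold, then averaging the fold scores. A minimal sketch, assuming a synthetic dataset from make_classification stands in for the KDD features and labels (the data/kdd_10pc file is not reproduced here), 5 folds, and shuffle=True with an arbitrary random_state=42. The question's ball_tree and leaf_size settings are left at their defaults since they mainly affect speed. The final lines run cross_val_score, which performs the same per-fold fit and score, as a cross-check.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Assumption: synthetic data stands in for the question's KDD features/labels.
X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# 5 folds; shuffle=True so the folds do not follow the file's row order.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
clf = KNeighborsClassifier(n_neighbors=5)

fold_scores = []
for train_idx, test_idx in kf.split(X):
    clf.fit(X[train_idx], y[train_idx])   # fit on the other k-1 folds
    pred = clf.predict(X[test_idx])       # predict the held-out fold
    fold_scores.append(accuracy_score(y[test_idx], pred))

print("Per-fold accuracy:", np.round(fold_scores, 4))
print("Mean accuracy: %.4f (std %.4f)" % (np.mean(fold_scores), np.std(fold_scores)))

# cross_val_score runs the same fit/score loop on a fresh clone of the
# estimator for each fold; with the same splitter it gives the same scores.
auto_scores = cross_val_score(KNeighborsClassifier(n_neighbors=5), X, y,
                              cv=kf, scoring="accuracy")
print("cross_val_score mean: %.4f" % auto_scores.mean())
```

On the real data, X and y would be features.values and labels.values from the question's code (the manual loop indexes with NumPy arrays). If the two classes are unbalanced, StratifiedKFold, which keeps the class ratio in every fold, is a reasonable drop-in replacement for KFold. If MinMaxScaler is added for the normalising TODO, passing sklearn.pipeline.make_pipeline(MinMaxScaler(), KNeighborsClassifier()) to cross_val_score refits the scaler on each training fold, so the test fold does not leak into the scaling.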

Leave P-out Cross Validation (here with p = 1, i.e. LeaveOneOut)

from sklearn.model_selection import LeaveOneOut

X = [1, 2, 3, 4]
loo = LeaveOneOut()
for train, test in loo.split(X):
    print("%s %s" % (train, test))

Output (train indices, then test indices, into X):
[1 2 3] [0]
[0 2 3] [1]
[0 1 3] [2]
[0 1 2] [3]
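LeaveOneOut is the p = 1 case of Leave P-out; sklearn's LeavePOut takes p as a parameter and uses every possible set of p samples as the test set once. A minimal sketch with p = 2 on the same four-element X, which gives C(4, 2) = 6 splits. Because the number of splits grows combinatorially with the number of samples, this (like LeaveOneOut) is only practical on small datasets, not on something the size of the KDD Cup 1999 data with KNN.

```python
from sklearn.model_selection import LeavePOut

X = [1, 2, 3, 4]
lpo = LeavePOut(p=2)

# get_n_splits needs X because the count depends on n_samples: C(4, 2) = 6
print(lpo.get_n_splits(X))  # prints 6

splits = []
for train, test in lpo.split(X):
    # each unordered pair of indices is the test set exactly once
    splits.append((train.tolist(), test.tolist()))
    print("%s %s" % (train, test))

# Printed splits (train indices, then test indices):
# [2 3] [0 1]
# [1 3] [0 2]
# [1 2] [0 3]
# [0 3] [1 2]
# [0 2] [1 3]
# [0 1] [2 3]
```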