SKLearn多类分类器

时间:2015-05-08 14:00:03

标签: python scikit-learn svm

我编写了以下代码来从文件导入数据向量并测试SVM分类器的性能(使用sklearn和python)。

然而,分类器性能低于任何其他分类器(例如NNet在测试数据上给出98%的准确度,但最多只能提供92%)。根据我的经验,SVM应该为这种数据产生更好的结果。

我可能做错了吗?

import numpy as np

def buildData(featureCols, testRatio):
    f = open("car-eval-data-1.csv")
    data = np.loadtxt(fname = f, delimiter = ',')

    X = data[:, :featureCols]  # select columns 0:featureCols-1
    y = data[:, featureCols]   # select column  featureCols 

    n_points = y.size
    print "Imported " + str(n_points) + " lines."

    ### split into train/test sets
    split = int((1-testRatio) * n_points)
    X_train = X[0:split,:]
    X_test  = X[split:,:]
    y_train = y[0:split]
    y_test  = y[split:]

    return X_train, y_train, X_test, y_test

def buildClassifier(features_train, labels_train):
    from sklearn import svm

    #clf = svm.SVC(kernel='linear',C=1.0, gamma=0.1)
    #clf = svm.SVC(kernel='poly', degree=3,C=1.0, gamma=0.1)
    clf = svm.SVC(kernel='rbf',C=1.0, gamma=0.1)
    clf.fit(features_train, labels_train)
    return clf

def checkAccuracy(clf, features, labels):
    from sklearn.metrics import accuracy_score

    pred = clf.predict(features)
    accuracy = accuracy_score(pred, labels)
    return accuracy

features_train, labels_train, features_test, labels_test = buildData(6, 0.3)
clf           = buildClassifier(features_train, labels_train)
trainAccuracy = checkAccuracy(clf, features_train, labels_train)
testAccuracy  = checkAccuracy(clf, features_test, labels_test)
print "Training Items: " + str(labels_train.size) + ", Test Items: " + str(labels_test.size)
print "Training Accuracy: " + str(trainAccuracy)
print "Test Accuracy: " + str(testAccuracy)

i = 0
while i < labels_test.size:
  pred = clf.predict(features_test[i])
  print "F(" + str(i) + ") : " + str(features_test[i]) + " label= " + str(labels_test[i]) + " pred= " + str(pred);
  i = i + 1

如果默认不这样做,怎么可能进行多类分类?

P.S。我的数据格式如下(最后一列是类):

2,2,2,2,2,1,0
2,2,2,2,1,2,0
0,2,2,5,2,2,3
2,2,2,4,2,2,1
2,2,2,4,2,0,0
2,2,2,4,2,1,1
2,2,2,4,1,2,1
0,2,2,5,2,2,3

1 个答案:

答案 0 :(得分:2)

我在很长一段时间后发现了这个问题并且我发布了它,以防有人需要它。

问题是数据导入功能不会改变数据。如果数据以某种方式排序,则存在使用某些数据训练分类器并使用完全不同的数据对其进行测试的风险。在NNet案例中,使用了Matlab自动混洗输入数据。

def buildData(filename, featureCols, testRatio):
f = open(filename)
data = np.loadtxt(fname = f, delimiter = ',')
np.random.shuffle(data)    # randomize the order

X = data[:, :featureCols]  # select columns 0:featureCols-1
y = data[:, featureCols]   # select column  featureCols 

n_points = y.size
print "Imported " + str(n_points) + " lines."

### split into train/test sets
split = int((1-testRatio) * n_points)
X_train = X[0:split,:]
X_test  = X[split:,:]
y_train = y[0:split]
y_test  = y[split:]

return X_train, y_train, X_test, y_test