sklearn预测函数对“ ovr”多类问题的不一致

时间:2019-05-03 12:54:12

标签: scikit-learn svm multiclass-classification

对于多类问题,我发现SVM模型的预测功能不一致。我已经用SKlearn SVM.SVC函数训练了一个模型来解决多类预测问题(请参见下图)。

但是在某些情况下,当我使用预测函数的argmax进行预测时,预测函数会给我不同的结果。可以看到不一致之处接近决策边界。

当我直接使用OneVsRestClassifier时,这种矛盾消失了。 SVM.SVC类的预测函数是否进行了一些更正?为什么它与argmax预测不同?

以下是重现结果的代码:

import numpy as np
from sklearn import svm, datasets
from sklearn.multiclass import OneVsRestClassifier
from scipy.linalg import cho_solve, cho_factor

def create_data(n_samples, noise):
    # 4 gaussian blobs with different means and variances
    sample_per_cls = np.int(n_samples/4)
    sample_per_cls_rest = sample_per_cls + n_samples - 4*sample_per_cls #puts the rest of the samples into the last class

    x1 = np.random.multivariate_normal([20, 18], np.array([[2, 3], [3, 7]])*4*noise, sample_per_cls, 'warn')
    x2 = np.random.multivariate_normal([13, 27], np.array([[10, 3], [3, 2]])*4*noise, sample_per_cls, 'warn')
    x3 = np.random.multivariate_normal([9, 13], np.array([[6, 1], [1, 5]])*4*noise, sample_per_cls, 'warn')
    x4 = np.random.multivariate_normal([14, 20], np.array([[4, 0.2], [0.2, 7]])*4*noise, sample_per_cls_rest, 'warn')

    X = np.vstack([x1,x2,x3,x4])

    #define the labels for each class
    Y = np.empty([n_samples], dtype=np.int)
    Y[0:sample_per_cls] = 0
    Y[sample_per_cls:2*sample_per_cls] = 1
    Y[2*sample_per_cls:3*sample_per_cls] = 2
    Y[3*sample_per_cls:] = 3

    #shuffle the data set
    rand_int = np.arange(n_samples)
    np.random.shuffle(rand_int)
    X = X[rand_int]
    Y = Y[rand_int]    
    return X, Y

X, Y = create_data(n_samples=800, noise=0.15)
clf = svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000)
#the classifier below is consistent
#clf = OneVsRestClassifier(svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000))
clf.fit(X,Y)

Xs = np.linspace(np.min(X[:,0] - 1), np.max(X[:,0] + 1), 150)
Ys = np.linspace(np.min(X[:,1] - 1), np.max(X[:,1] + 1), 150)
XX, YY = np.meshgrid(Xs, Ys)
test_set = np.stack([XX, YY], axis=2).reshape(-1,2)

#prediction via argmax of the decision function
pred = np.argmax(clf.decision_function(test_set), axis=1)

#prediction with sklearn function
pred_1 = clf.predict(test_set)
diff = np.equal(pred, pred_1)
error = np.where(diff == False)[0]
print(error)

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [16, 10]
plt.contourf(XX, YY, pred_1.reshape(XX.shape), alpha=0.5, cmap='seismic')
plt.colorbar()
plt.scatter(X[:,0], X[:,1], c=Y, s=20, marker='o', edgecolors='k')
plt.scatter(test_set[error, 0], test_set[error, 1], c=pred_1[error], s=120, marker='^', edgecolors='k')
plt.show()

三角形标记不一致的点: Triangles are the inconsistent points

0 个答案:

没有答案