在多类问题上,我发现 SVM 模型的 predict 函数给出的结果不一致。我使用 scikit-learn 的 svm.SVC
训练了一个模型来解决多类预测问题(请参见下图)。
但在某些情况下,predict 函数给出的结果与我对决策函数(decision_function)取 argmax 得到的预测不同。可以看到,这些不一致之处都靠近决策边界。
当我直接使用 OneVsRestClassifier 时,这种不一致就消失了。
svm.SVC 类的 predict 函数是否做了某种修正?为什么它的结果与 argmax 预测不同?
以下是重现结果的代码:
import numpy as np
from sklearn import svm, datasets
from sklearn.multiclass import OneVsRestClassifier
from scipy.linalg import cho_solve, cho_factor
def create_data(n_samples, noise):
    """Draw a shuffled 2-D data set made of four Gaussian blobs.

    Each class is sampled from its own bivariate normal distribution with a
    fixed mean and a fixed covariance matrix scaled by ``4 * noise``.

    Args:
        n_samples: Total number of points to generate. The points are split
            evenly over the four classes; any remainder is added to the
            last class (label 3).
        noise: Scale factor for the spread of every blob (each covariance
            matrix is multiplied by ``4 * noise``).

    Returns:
        A tuple ``(X, Y)`` where ``X`` is an ``(n_samples, 2)`` array of
        points and ``Y`` is an integer array of length ``n_samples`` holding
        the labels 0..3. Both are shuffled with the same permutation.
    """
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy, so plain integer arithmetic is used instead.
    sample_per_cls, remainder = divmod(n_samples, 4)
    # Put the rest of the samples into the last class.
    sample_per_cls_rest = sample_per_cls + remainder

    # (mean, base covariance, number of samples) for labels 0, 1, 2, 3.
    blobs = [
        ([20, 18], [[2, 3], [3, 7]], sample_per_cls),
        ([13, 27], [[10, 3], [3, 2]], sample_per_cls),
        ([9, 13], [[6, 1], [1, 5]], sample_per_cls),
        ([14, 20], [[4, 0.2], [0.2, 7]], sample_per_cls_rest),
    ]
    # Drawn in label order so the random stream is consumed exactly as
    # four consecutive multivariate_normal calls would consume it.
    X = np.vstack([
        np.random.multivariate_normal(
            mean, np.array(cov) * 4 * noise, count, check_valid='warn')
        for mean, cov, count in blobs
    ])

    # Define the labels for each class: label i repeated once per sample.
    Y = np.repeat(np.arange(len(blobs)), [count for _, _, count in blobs])

    # Shuffle the data set, applying the same permutation to X and Y.
    rand_int = np.arange(n_samples)
    np.random.shuffle(rand_int)
    X = X[rand_int]
    Y = Y[rand_int]
    return X, Y
# Build the four-class toy data set and fit a single multi-class SVC on it.
X, Y = create_data(n_samples=800, noise=0.15)
clf = svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000)
# The classifier below is consistent.
#clf = OneVsRestClassifier(svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000))
clf.fit(X, Y)

# Regular 150 x 150 grid covering the data range plus a margin of 1 on every
# side, flattened to one (x, y) row per grid point.
Xs = np.linspace(np.min(X[:, 0] - 1), np.max(X[:, 0] + 1), 150)
Ys = np.linspace(np.min(X[:, 1] - 1), np.max(X[:, 1] + 1), 150)
XX, YY = np.meshgrid(Xs, Ys)
test_set = np.stack([XX, YY], axis=2).reshape(-1, 2)

# Prediction via argmax of the decision function. argmax yields a column
# index, so it is mapped through clf.classes_ to obtain an actual label; for
# the labels 0..3 used here the two coincide.
# NOTE(review): scikit-learn documents that SVC always trains one-vs-one
# internally and that 'ovr' only reshapes the decision function, which would
# explain disagreements with predict() near the boundaries -- confirm against
# the SVC documentation (see the `break_ties` parameter).
pred = clf.classes_[np.argmax(clf.decision_function(test_set), axis=1)]
# Prediction with the sklearn predict function.
pred_1 = clf.predict(test_set)

# Indices of the grid points on which the two predictions disagree.
diff = np.equal(pred, pred_1)
error = np.flatnonzero(~diff)
print(error)
import matplotlib.pyplot as plt
# Draw the predicted regions, the training points and the grid points on
# which predict() and the argmax prediction disagree.
plt.rcParams['figure.figsize'] = [16, 10]

# Filled contours of the predict() labels over the evaluation grid.
region_labels = pred_1.reshape(XX.shape)
plt.contourf(XX, YY, region_labels, alpha=0.5, cmap='seismic')
plt.colorbar()

# Training samples, coloured by their true label.
train_x, train_y = X[:, 0], X[:, 1]
plt.scatter(train_x, train_y, c=Y, s=20, marker='o', edgecolors='k')

# Disagreeing grid points as large triangles, coloured by the predict() label.
mismatch_points = test_set[error]
mismatch_labels = pred_1[error]
plt.scatter(
    mismatch_points[:, 0],
    mismatch_points[:, 1],
    c=mismatch_labels,
    s=120,
    marker='^',
    edgecolors='k',
)
plt.show()