如何用Sklearn得到线性判别分析中的边界线方程

时间:2016-04-20 13:31:53

标签: python machine-learning scikit-learn

我将一些数据分为两类,使用sklearn的LinearDiscriminantAnalysis分类器,效果很好,所以我这样做了:

from sklearn.cross_validation import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25) # 25% of the dataset are not used for the training
clf = LDA()
clf.fit(x_train, y_train)

然后我设法用它进行预测,这很好。

但是,所有这些都在ipython笔记本中,我想在其他地方使用分类器。我已经看到了使用泡菜和joblib的可能性,但因为我只有2组和2个特征,所以我虽然我可以只是得到边界线的等式,然后检查是否给定点高于或低于该行以告知它属于哪个组。

根据我的理解,这条线与投影线正交,并通过簇的平均值。我想我用np.mean(clf.means_, axis=0)得到了集群的意思。

但在这里,我仍然坚持如何使用clf.coef_clf.intercept_等所有属性来查找投影线的等式。

所以,我的问题是如何根据我的分类器得到边界线方程。

我也可能没有正确理解LDA,我很乐意有更多的解释。

由于

1 个答案:

答案 0 :(得分:3)

决策边界只是用

给出的行
np.dot(clf.coef_, x) - clf.intercept_ = 0

(直到拦截的符号,这取决于实施可能被翻转),因为这是决策功能的符号翻转的地方。