您好,我正在尝试重现 scikit-learn 中为投票分类器(VotingClassifier)绘制决策边界的示例。
分类部分相当直接,而把多个子图简洁地绘制在同一张图中的写法也很有意思。但是,我无法更改配色方案。
下面是分类部分的代码,非常直接:
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import VotingClassifier
# Load the iris data set and keep only feature columns 0 and 2, giving a
# two-column X that can be drawn on a 2-D grid later.
iris = datasets.load_iris()
X = iris.data[:, [0, 2]]
y = iris.target

# Three individual classifiers plus a soft-voting ensemble that combines
# them with weights 2 (tree), 1 (KNN) and 2 (SVM).
clf1 = DecisionTreeClassifier(max_depth=4)
clf2 = KNeighborsClassifier(n_neighbors=7)
clf3 = SVC(kernel='rbf', probability=True)
eclf = VotingClassifier(
    estimators=[('dt', clf1), ('knn', clf2), ('svc', clf3)],
    voting='soft',
    weights=[2, 1, 2],
)

# Fit every model, in the same order as before, on the same data; each one
# is plotted on its own subplot afterwards.
for model in (clf1, clf2, clf3, eclf):
    model.fit(X, y)
该示例使用以下代码创建图:
# Plotting decision regions
# Build a grid that covers the data range (padded by 1 unit on every side)
# with a 0.1 step; every grid point will be classified.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

f, axarr = plt.subplots(2, 2, sharex='col', sharey='row', figsize=(10, 8))

# One subplot per classifier: product([0, 1], [0, 1]) yields the four
# (row, col) positions of the 2x2 grid in row-major order.
for idx, clf, tt in zip(product([0, 1], [0, 1]),
                        [clf1, clf2, clf3, eclf],
                        ['Decision Tree (depth=4)', 'KNN (k=7)',
                         'Kernel SVM', 'Soft Voting']):
    # Predict a class for every grid point, then restore the grid shape.
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # No cmap is passed to either call, so matplotlib's default colormap is
    # used for both the filled regions and the points.
    axarr[idx[0], idx[1]].contourf(xx, yy, Z, alpha=0.4)
    axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y,
                                  s=20, edgecolor='k')
    axarr[idx[0], idx[1]].set_title(tt)

plt.show()
似乎 matplotlib 总是使用默认的配色方案。有没有办法传入其他颜色?我尝试修改 `c=y`(例如改成 `c=['y', 'b']`),但这并没有解决问题。
我想同时更改背景颜色和散点的颜色。有什么想法吗?
答案 0(得分:2)
颜色是根据各个子图中 `y` 和 `Z` 的值来选择的。`y` 的条目数与数据点数相同,并且有 3 个唯一值;`Z` 同样有 3 个级别。这些值会按照 matplotlib 的色彩映射表(colormaps)被映射为颜色。因此,要更改颜色,应当传入 `cmap` 参数,而不是修改 `c=y`。
您可以选择其他色彩映射表,例如 `cmap="brg"`:
# Pass the same colormap ("brg") to both the filled decision regions and the
# scattered points, instead of relying on the default colormap.
ax = axarr[idx]
ax.contourf(xx, yy, Z, alpha=0.4, cmap="brg")
ax.scatter(X[:, 0], X[:, 1], c=y, cmap="brg", s=20, edgecolor='w')
完整代码:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
# Load iris and keep feature columns 0 and 2 so the regions can be drawn in 2-D.
iris = datasets.load_iris()
X = iris.data[:, [0, 2]]
y = iris.target

clf1 = DecisionTreeClassifier(max_depth=4)
clf2 = KNeighborsClassifier(n_neighbors=7)
clf1.fit(X, y)
clf2.fit(X, y)

# Grid covering the data range (padded by 1 unit) with a 0.1 step.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

f, axarr = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(5, 3))
for idx, clf, tt in zip([0, 1], [clf1, clf2],
                        ['Decision Tree (depth=4)', 'KNN (k=7)']):
    # Classify every grid point and restore the grid shape.
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # The same colormap ("brg") is passed to the regions and the points.
    axarr[idx].contourf(xx, yy, Z, alpha=0.4, cmap="brg")
    axarr[idx].scatter(X[:, 0], X[:, 1], c=y, cmap="brg",
                       s=20, edgecolor='w')
    axarr[idx].set_title(tt)

plt.show()
您也可以创建自定义色彩映射表,例如使用金色(gold)、深红色(crimson)和靛蓝色(indigo)这三种颜色:
import matplotlib.colors
# A custom colormap with exactly three colours, matching the three distinct
# values in y.
palette = ["gold", "crimson", "indigo"]
cmap = matplotlib.colors.ListedColormap(palette)
ax = axarr[idx]
ax.contourf(xx, yy, Z, alpha=0.4, cmap=cmap)
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, s=20, edgecolor='w')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
# Load iris and keep feature columns 0 and 2 so the regions can be drawn in 2-D.
iris = datasets.load_iris()
X = iris.data[:, [0, 2]]
y = iris.target

clf1 = DecisionTreeClassifier(max_depth=4)
clf2 = KNeighborsClassifier(n_neighbors=7)
clf1.fit(X, y)
clf2.fit(X, y)

# Grid covering the data range (padded by 1 unit) with a 0.1 step.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

f, axarr = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(5, 3))

# Custom three-colour colormap, one colour per distinct value in y.
cmap = matplotlib.colors.ListedColormap(["gold", "crimson", "indigo"])

for idx, clf, tt in zip([0, 1], [clf1, clf2],
                        ['Decision Tree (depth=4)', 'KNN (k=7)']):
    # Classify every grid point and restore the grid shape.
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # The same custom colormap is passed to the regions and the points.
    axarr[idx].contourf(xx, yy, Z, alpha=0.4, cmap=cmap)
    axarr[idx].scatter(X[:, 0], X[:, 1], c=y, cmap=cmap,
                       s=20, edgecolor='w')
    axarr[idx].set_title(tt)

plt.show()