如何使用matplotlib绘制回归的决策边界?

时间:2013-11-18 10:54:31

标签: python matplotlib contour scatter-plot logistic-regression

如何将逻辑回归结果的等值线图添加到散点图?我想要彩色0/1区域,它描绘了分类器的决策边界。

import pandas as pd
import numpy as np
import pylab as pl
import statsmodels.api as sm

# Build X, Y from file
f = open('ex2data2.txt')
lines = f.readlines()
x1 = []
x2 = []
y = []
for line in lines:
    line = line.replace("\n", "")
    vals = line.split(",")
    x1.append(float(vals[0]))
    x2.append(float(vals[1]))
    y.append(int(vals[2]))

x1 = np.array(x1)
x2 = np.array(x2)
y = np.array(y)

x = np.vstack([x1, x2]).T

# Scatter plot 0/1s
pos_mask = y == 1
neg_mask = y == 0
pos_x1 = x1[pos_mask]
neg_x1 = x1[neg_mask]
pos_x2 = x2[pos_mask]
neg_x2 = x2[neg_mask]
pl.clf()
pl.scatter(pos_x1, pos_x2, c='r')
pl.scatter(neg_x1, neg_x2, c='g')

# Run logistic regression
logit = sm.Logit(y, x)
result = logit.fit()
result.params
result.predict([1.0, 1.0])

# Now I want to add a countour for 0/1 regression results to the scatter plot.

1 个答案:

答案 0 :(得分:5)

我会尝试回答,但您必须了解一些关于我的答案的假设,这些假设可能适用于您的代码,也可能不适用于您的代码:

我的进口商品:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

X包含您的功能,如下所示:

print type(X)
<type 'numpy.ndarray'>

显示为102,2:

print X
[[-13.15490196 -23.        ]
[-22.95490196 -25.        ]
[-12.75490196  -8.        ]
[  0.14509804  -6.        ]
.
.
.

ytrain包含基本事实,在这种情况下是布尔值,但你可以做0/1同样的事情。

print type(ytrain)
<type 'numpy.ndarray'>

是51,

print (train)
[False False False False  True  True  True  True  True  True False  True
False  True  True  True False False False  True  True  True  True  True
False False False False  True  True  True  True  True  True False  True
False  True  True  True False False False False False  True  True  True
False  True False]

最后clf包含您的模型,在我的例子中是一个合适的模型 另外我使用scikit学习的LogisticRegression,这取决于我的clf.predict_proba 提供构建标签和轮廓所需的信息。我不熟悉 你正在使用的确切包装,但请记住这一点。

# evenly sampled points
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 50),
                     np.linspace(y_min, y_max, 50))
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())

#plot background colors
ax = plt.gca()
Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
Z = Z.reshape(xx.shape)
cs = ax.contourf(xx, yy, Z, cmap='RdBu', alpha=.5)
cs2 = ax.contour(xx, yy, Z, cmap='RdBu', alpha=.5)
plt.clabel(cs2, fmt = '%2.1f', colors = 'k', fontsize=14)

# Plot the points
ax.plot(Xtrain[ytrain == 0, 0], Xtrain[ytrain == 0, 1], 'ro', label='Class 1')
ax.plot(Xtrain[ytrain == 1, 0], Xtrain[ytrain == 1, 1], 'bo', label='Class 2')

# make legend
plt.legend(loc='upper left', scatterpoints=1, numpoints=1)

您的结果将如下所示:

enter image description here

相关问题