我正在尝试绘制Keras模型预测的决策图边界。但是,生成的边界似乎不正确。
这是我的模特
def base():
model = Sequential()
model.add(Dense(5,activation = 'relu', input_dim = 2))
model.add(Dense(2,activation = 'relu'))
model.add(Dense(1,activation = 'sigmoid'))
model.compile(optimizer = optimizers.SGD(lr=0.0007, momentum=0.0, decay=0.0), loss = 'binary_crossentropy', metrics= ['accuracy'])
return model
model = base()
history = model.fit(train_X,train_Y, epochs = 10000, batch_size =64, verbose = 2)
这是我的绘图函数(取自here)
def plot_decision_boundary(X, y, model, steps=1000, cmap='Paired'):
"""
Function to plot the decision boundary and data points of a model.
Data points are colored based on their actual label.
"""
cmap = get_cmap(cmap)
# Define region of interest by data limits
xmin, xmax = X[:,0].min() - 1, X[:,0].max() + 1
ymin, ymax = X[:,1].min() - 1, X[:,1].max() + 1
steps = 1000
x_span = linspace(xmin, xmax, steps)
y_span = linspace(ymin, ymax, steps)
xx, yy = meshgrid(x_span, y_span)
# Make predictions across region of interest
labels = model.predict(c_[xx.ravel(), yy.ravel()])
# Plot decision boundary in region of interest
z = labels.reshape(xx.shape)
fig, ax = subplots()
ax.contourf(xx, yy, z, cmap=cmap, alpha=0.5)
# Get predicted labels on training data and plot
train_labels = model.predict(X)
ax.scatter(X[:,0], X[:,1], c=y.ravel(), cmap=cmap, lw=0)
return fig, ax
plot_decision_boundary(train_X,train_Y, model, cmap = 'RdBu')
我得到了这样的情节
显然,这是对绘图决策边界的非常有缺陷的描述(由于存在如此多的边界,因此根本无法提供信息)。有人可以指出我的错误吗?
答案 0 :(得分:1)
由于概率是从0到1的连续值,所以轮廓越来越多。
如果可视化仅限于2类(输出为2D softmax向量),则可以使用此简单代码
def plotModelOut(x,y,model):
'''
x,y: 2D MeshGrid input
model: Keras Model API Object
'''
grid = np.stack((x,y))
grid = grid.T.reshape(-1,2)
outs = model.predict(grid)
y1 = outs.T[0].reshape(x.shape[0],x.shape[0])
plt.contourf(x,y,y1)
plt.show()
这将给出轮廓(一个以上),如果您想要一条轮廓线,则可以执行以下操作
您可以对从model.predict
输出的概率进行阈值处理,并显示一条轮廓线。
例如,
import numpy as np
from matplotlib import pyplot as plt
a = np.linspace(-5, 5, 100)
xx, yy = np.meshgrid(a,a)
z = xx**2 + yy**2
# z = z > 5 (Threshold value)
plt.contourf(xx, yy, z,)
plt.show()
在未评论阈值的情况下,我们获得了2张图片
具有连续值的多个轮廓
z为阈值时的单个轮廓(z = z> 5)
类似的方法可以在输出softmax向量上使用类似的方法
label = label > 0.5
有关可视化代码的更多信息,请参见IITM CVI Blog