我是 Dash/Plot.ly 的新手。目前,我正在尝试用 Dash 复现下面这张图表(原图由 matplotlib 制作):
为此,我编写了一个用于生成单个图形的函数:
def serve_prediction_plot(model, title, X, X_proj, y, y_proc, train_idx, test_idx, Z, xx, yy, x0, y0, d):
    """Build a Plotly figure of a classifier's prediction surface and its data.

    The figure has three traces: a heatmap of the grid predictions ``Z``,
    the training points (circles) and the test points (triangles), with
    markers coloured by ``y_proc``.

    Parameters
    ----------
    model : estimator
        Scored with ``cross_val_score`` on the training rows and with
        ``model.score`` on the test rows (assumed already fitted).
    title : str
        Figure title.
    X : array
        Feature matrix used for scoring, indexed by ``train_idx`` / ``test_idx``.
    X_proj : array
        Coordinates used to place the markers (columns 0 and 1 are plotted).
    y : array
        Labels shown as marker hover text.
    y_proc : array
        Labels used for scoring and for marker colour.
    train_idx, test_idx : index arrays
        Rows belonging to the training and the test split.
    Z : array
        Prediction values on the grid; reshaped to ``xx.shape``.
    xx, yy : arrays
        Meshgrid of the plotting area (only ``xx.shape`` is used here).
    x0, y0 : float
        Lower-left corner of the heatmap.
    d : float
        Grid step, used as the heatmap cell size on both axes.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    # Get train and test score from model
    train_score = cross_val_score(model, X[train_idx], y_proc[train_idx]).mean()
    test_score = model.score(X[test_idx], y_proc[test_idx])

    # Colorscales: red/blue for the markers, 9-step RdBu for the surface.
    bright_cscale = [[0, '#FF0000'], [1, '#0000FF']]
    # Use a float step so the stops are 0, 1/8, ..., 1 even under Python 2
    # (where ``1 / 8`` is 0 and np.arange would fail).
    colorscale_zip = zip(np.arange(0, 1.01, 1.0 / 8), cl.scales['9']['div']['RdBu'])
    cscale = list(map(list, colorscale_zip))

    axis_template = dict(
        showgrid=False,
        zeroline=False,
        linecolor='white',
        showticklabels=False,
        ticks=''
    )

    layout = dict(
        title=title,
        xaxis=axis_template,
        yaxis=axis_template,
        showlegend=False,
        hovermode='closest',
        # With autosize=False and no explicit width/height, plotly.js falls
        # back to its fixed default size, which overflows the narrow
        # "two columns" container and gets clipped. Let the figure size
        # itself to its container instead.
        autosize=True,
        margin=dict(l=0, r=0, t=30, b=0)
    )

    # Plot the prediction contour of the models
    Z = Z.reshape(xx.shape)
    trace0 = go.Heatmap(
        z=Z,
        hoverinfo='none',
        showscale=False,
        colorscale=cscale,
        x0=x0,
        y0=y0,
        dx=d,
        dy=d
    )

    # Plot Training Data
    trace1 = go.Scatter(
        x=X_proj[train_idx, 0],
        y=X_proj[train_idx, 1],
        mode='markers',
        name='Training Data (accuracy={:.3f})'.format(train_score),
        text=y[train_idx],
        marker=dict(
            size=10,
            color=y_proc[train_idx],
            colorscale=bright_cscale,
            line=dict(
                width=1
            )
        )
    )

    # Plot Test Data -- labelled with the *test* score (was train_score).
    trace2 = go.Scatter(
        x=X_proj[test_idx, 0],
        y=X_proj[test_idx, 1],
        mode='markers',
        name='Test Data (accuracy={:.3f})'.format(test_score),
        text=y[test_idx],
        marker=dict(
            size=10,
            symbol='triangle-up',
            color=y_proc[test_idx],
            colorscale=bright_cscale,
            line=dict(
                width=1
            ),
        )
    )

    data = [trace0, trace1, trace2]
    figure = go.Figure(data=data, layout=layout)
    return figure
在构建 Dash 视图时会调用该函数:
def generate_dense_maps():
    """Return a Dash row with one prediction graph per classifier in ``service``.

    Each entry of ``service.classifiers`` becomes a ``dcc.Graph`` with id
    ``graph-<name>``, wrapped in a ``"two columns"`` div. The figure comes
    from ``serve_prediction_plot``.
    """
    # Row style: small top margin, and text selection disabled (with every
    # vendor prefix) for better UX.
    row_style = {
        'margin-top': '5px',
        'user-select': 'none',
        '-moz-user-select': 'none',
        '-webkit-user-select': 'none',
        '-ms-user-select': 'none'
    }

    cells = []
    for clf_name, clf in service.classifiers.items():
        ds = service.dataset
        graph = dcc.Graph(
            id='graph-{name}'.format(name=clf_name),
            figure=serve_prediction_plot(
                clf,
                clf_name,
                ds.X,
                ds.X_proj,
                ds.y,
                ds.y_proc,
                ds.train_idx,
                ds.test_idx,
                service.get_prediction(clf),
                service.grid.xx,
                service.grid.yy,
                service.x_min,
                service.y_min,
                service.grid.h,
            ),
        )
        cells.append(html.Div([graph], className="two columns"))

    return html.Div(className='row', style=row_style, children=cells)
# -------------------- Dash --------------------
# Create the Dash app and attach the page layout: a title banner followed by
# a body row containing the classifier maps and an (empty) uncertainty section.
app = dash.Dash(__name__)
app.layout = html.Div(children=[
    # -------------------- Title Bar --------------------
    html.Div(className="banner", children=[
        html.Div(className='container scalable', children=[
            html.H2(html.A(
                'Title goes here',
                style={
                    'text-decoration': 'none',
                    'color': 'inherit'
                }
            )),
            html.A(
                html.Img(src="https://s3-us-west-1.amazonaws.com/plotly-tutorials/logo/new-branding/dash-logo-by-plotly-stripe-inverted.png"),
                href='https://plot.ly/products/dash/'
            )
        ]),
    ]),
    # -------------------- Body -------------------------
    html.Div(id='body', className='container scalable', children=[
        html.Div(className='row', children=[
            # -------------------- Classifiers ------------------
            html.Div(
                id='div-classifiers', children=[
                    html.H4(html.A(
                        'Classifiers',
                        style={
                            'text-decoration': 'none',
                            'color': 'inherit'
                        }
                    )),
                    generate_dense_maps()
                ]
            ),
            # -------------------- Uncertainty ------------------
            # Placeholder section. The id follows the same 'div-<section>'
            # naming as 'div-classifiers' (it was mistyped 'div=uncertainty').
            html.Div(
                id='div-uncertainty'
            )
        ])
    ])
])
但是,显示出来的图像被裁剪了(显示不完整)。
我想知道自己遗漏了什么,以及如何正确实现想要的效果。我还尝试过按如下方式排列图像(我认为这样在网页上的实际效果会更好):
image1 | image2
image3 | image4
image5 | image6
但没有成功。
最小可复现示例:
import numpy as np
# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split
# now lives in sklearn.model_selection. Fall back only for very old releases.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Display names, paired positionally with the estimators below.
names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Decision Tree",
         "Random Forest"]

classifiers = [
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025, probability=True),
    SVC(gamma=2, C=1, probability=True),
    DecisionTreeClassifier(max_depth=5),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
]

# Synthetic 2-feature binary classification data, shifted by uniform noise
# in [0, 2) on every coordinate.
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
                           random_state=1, n_clusters_per_class=1)
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)

h = .02  # step size of the prediction mesh

X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4)

# Mesh covering all points with a 0.5 margin on each side.
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))
import plotly.graph_objs as go
import colorlover as cl
from sklearn.model_selection import cross_val_score
def serve_prediction_plot(model, title, X_train, y_train, X_test, y_test, xx, yy, d):
    """Fit ``model`` and return a Plotly figure of its prediction surface.

    The figure has three traces: a heatmap of the model's output over the
    mesh ``xx``/``yy``, the training points (circles) and the test points
    (triangles), with markers coloured by their labels.

    Parameters
    ----------
    model : scikit-learn classifier
        Fitted in place on ``(X_train, y_train)``.
    title : str
        Figure title.
    X_train, y_train : array-like
        Training features (columns 0 and 1 are plotted) and labels.
    X_test, y_test : array-like
        Test features (columns 0 and 1 are plotted) and labels.
    xx, yy : ndarray
        Meshgrid of the plotting area; the heatmap starts at
        ``(xx.min(), yy.min())``.
    d : float
        Mesh step, used as the heatmap cell size on both axes.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    # Get train and test score from model
    model.fit(X_train, y_train)
    train_score = cross_val_score(model, X_train, y_train).mean()
    test_score = model.score(X_test, y_test)

    # Colorscales: red/blue for the markers, 9-step RdBu for the surface.
    bright_cscale = [[0, '#FF0000'], [1, '#0000FF']]
    # Float step keeps the stops at 0, 1/8, ..., 1 even under Python 2.
    colorscale_zip = zip(np.arange(0, 1.01, 1.0 / 8), cl.scales['9']['div']['RdBu'])
    cscale = list(map(list, colorscale_zip))

    axis_template = dict(
        showgrid=False,
        zeroline=False,
        linecolor='white',
        showticklabels=False,
        ticks=''
    )

    layout = dict(
        title=title,
        xaxis=axis_template,
        yaxis=axis_template,
        showlegend=False,
        hovermode='closest',
        # With autosize=False and no explicit width/height, plotly.js uses
        # its fixed default size, which overflows the narrow "two columns"
        # container and gets clipped. Size the figure to its container.
        autosize=True,
        margin=dict(l=0, r=0, t=30, b=0)
    )

    # Plot the prediction contour of the models
    grid = np.c_[xx.ravel(), yy.ravel()]
    try:
        Z = model.predict_proba(grid)[:, 1]
    except (AttributeError, NotImplementedError):
        # Estimators without probability estimates (e.g. SVC with
        # probability=False) raise AttributeError on predict_proba, so
        # catching only NotImplementedError left this fallback unreachable.
        Z = model.decision_function(grid)
    Z = Z.reshape(xx.shape)

    trace0 = go.Heatmap(
        z=Z,
        hoverinfo='none',
        showscale=False,
        colorscale=cscale,
        x0=xx.min(),
        y0=yy.min(),
        dx=d,
        dy=d
    )

    # Plot Training Data
    trace1 = go.Scatter(
        x=X_train[:, 0],
        y=X_train[:, 1],
        mode='markers',
        name='Training Data (accuracy={:.3f})'.format(train_score),
        text=y_train,
        marker=dict(
            size=10,
            color=y_train,
            colorscale=bright_cscale,
            line=dict(
                width=1
            )
        )
    )

    # Plot Test Data -- labelled with the *test* score (was train_score).
    trace2 = go.Scatter(
        x=X_test[:, 0],
        y=X_test[:, 1],
        mode='markers',
        name='Test Data (accuracy={:.3f})'.format(test_score),
        text=y_test,
        marker=dict(
            size=10,
            symbol='triangle-up',
            color=y_test,
            colorscale=bright_cscale,
            line=dict(
                width=1
            ),
        )
    )

    data = [trace0, trace1, trace2]
    figure = go.Figure(data=data, layout=layout)
    return figure
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
def generate_dense_maps():
    """Return a Dash row with one prediction graph per classifier.

    Pairs the module-level ``names`` and ``classifiers`` lists. Each graph is
    built from the module-level train/test split, the mesh ``xx``/``yy`` and
    the step ``h``.
    """
    # Row style: small top margin, and text selection disabled (with every
    # vendor prefix) for better UX.
    row_style = {
        'margin-top': '5px',
        'user-select': 'none',
        '-moz-user-select': 'none',
        '-webkit-user-select': 'none',
        '-ms-user-select': 'none'
    }

    def _graph_cell(clf_name, clf):
        # One "two columns" cell holding this classifier's figure.
        graph = dcc.Graph(
            id='graph-{name}'.format(name=clf_name),
            figure=serve_prediction_plot(clf, clf_name, X_train, y_train,
                                         X_test, y_test, xx, yy, h),
        )
        return html.Div([graph], className="two columns")

    return html.Div(
        className='row',
        style=row_style,
        children=[_graph_cell(label, estimator)
                  for label, estimator in zip(names, classifiers)],
    )
# -------------------- Dash --------------------
# Create the app and attach the page layout: a title banner followed by a
# body row holding the classifier prediction maps.
app = dash.Dash(__name__)
app.layout = html.Div([
    # -------------------- Title Bar --------------------
    html.Div(
        [
            html.Div(
                [
                    html.H2(
                        html.A(
                            'Title goes here',
                            style={'text-decoration': 'none', 'color': 'inherit'},
                        )
                    ),
                    html.A(
                        html.Img(
                            src="https://s3-us-west-1.amazonaws.com/plotly-tutorials/logo/new-branding/dash-logo-by-plotly-stripe-inverted.png"
                        ),
                        href='https://plot.ly/products/dash/',
                    ),
                ],
                className='container scalable',
            ),
        ],
        className="banner",
    ),
    # -------------------- Body -------------------------
    html.Div(
        [
            html.Div(
                [
                    # -------------------- Classifiers ------------------
                    html.Div(
                        [
                            html.H4(
                                html.A(
                                    'Classifiers',
                                    style={'text-decoration': 'none', 'color': 'inherit'},
                                )
                            ),
                            generate_dense_maps(),
                        ],
                        id='div-classifiers',
                    ),
                ],
                className='row',
            ),
        ],
        id='body',
        className='container scalable',
    ),
])

# External stylesheets, appended in this order: normalize, fonts, base, custom.
# NOTE(review): the rawgit.com URLs may no longer resolve (the service was
# retired) -- consider serving these files locally instead.
external_css = [
    "https://cdnjs.cloudflare.com/ajax/libs/normalize/7.0.0/normalize.min.css",
    "https://fonts.googleapis.com/css?family=Open+Sans|Roboto",
    "https://maxcdn.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css",
    "https://rawgit.com/xhlulu/9a6e89f418ee40d02b637a429a876aa9/raw/f3ea10d53e33ece67eb681025cedc83870c9938d/base-styles.css",
    "https://cdn.rawgit.com/plotly/dash-svm/bb031580/custom-styles.css",
]
for css in external_css:
    app.css.append_css({"external_url": css})

# Start the development server.
app.run_server(debug=True)