Changing the x labels in a Python sklearn partial dependence plot

Time: 2017-11-20 17:04:40

Tags: python scikit-learn sklearn-pandas

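I fit a GradientBoostingRegressor on standardized data and plot the partial dependence of the 10 most important variables. Now I would like to plot them against the real, unstandardized values, so I need to access the x labels. How can I do that?

My code is comparable to http://scikit-learn.org/stable/auto_examples/ensemble/plot_partial_dependence.html

For the 3D plot, I can easily convert the axes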

axes[0] = (axes[0]*mysd0)+mymean0
axes[1] = (axes[1]*mysd1)+mymean1

using the mean and standard deviation, but for the subplots I don't know how to access the labels. Thanks
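The axis conversion above is the inverse of standardization, x = z * sd + mean. Below is a minimal sketch of where values such as mymean0 and mysd0 could come from, assuming the feature was standardized with the population standard deviation (as sklearn's StandardScaler does). The helper names and sample values are made up for illustration.

```python
from statistics import mean, pstdev

def standardize(values):
    """Return z-scores plus the mean and sd needed to undo the scaling."""
    mu = mean(values)
    sd = pstdev(values)  # population standard deviation (ddof=0)
    return [(v - mu) / sd for v in values], mu, sd

def unstandardize(z_values, mu, sd):
    """Invert standardization: x = z * sd + mean."""
    return [z * sd + mu for z in z_values]

raw = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # illustrative feature column
z, mymean0, mysd0 = standardize(raw)
restored = unstandardize(z, mymean0, mysd0)      # back on the original scale
```

The same mean and standard deviation that were used to scale a feature have to be reused for its axis, one pair per feature.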

Here is the part of the code I am talking about:

from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble.partial_dependence import plot_partial_dependence
from sklearn.datasets.california_housing import fetch_california_housing

cal_housing = fetch_california_housing()

# split 80/20 train-test
X_train, X_test, y_train, y_test = train_test_split(cal_housing.data,
                                                    cal_housing.target,
                                                    test_size=0.2,
                                                    random_state=1)
names = cal_housing.feature_names
clf = GradientBoostingRegressor(n_estimators=100, max_depth=4,
                                learning_rate=0.1, loss='huber',
                                random_state=1)
clf.fit(X_train, y_train)
features = [0, 5, 1]
fig, axs = plot_partial_dependence(clf, X_train, features,
                                   feature_names=names,
                                   n_jobs=3, grid_resolution=50)
fig.suptitle('Partial dependence of house value on nonlocation features\n'
             'for the California housing dataset')

In this plot, I would like to access and manipulate the x-axis labels...

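2 Answers:

Answer 0: (score: 1)

If I understand correctly, you want to access the labels according to feature importance.

If that is the case, you can do the following: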

import numpy as np

# after fitting the model, use this to get the feature importances
feature_importance = clf.feature_importances_

# make importances relative to the max importance
feature_importance = 100.0 * (feature_importance / feature_importance.max())

# sort the importances and get the indices of the sorting
sorted_idx = np.argsort(feature_importance)

# match the indices with the column labels of the x matrix
# important: x must be a DataFrame with column names for this to work
x.columns[sorted_idx]

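This gives you the feature names in ascending order of importance: the first name is the least important feature and the last name is the most important one.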
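If the training data is a plain NumPy array, as in the question, it has no columns attribute and the names live in a separate list. Here is a minimal sketch of the same sorting idea in plain Python; the importance values are hypothetical and only serve to illustrate the ordering.

```python
names = ["MedInc", "HouseAge", "AveRooms"]  # subset of the feature names
importances = [0.6, 0.1, 0.3]               # hypothetical importances

# Indices that sort the importances in ascending order (like np.argsort)
sorted_idx = sorted(range(len(importances)), key=importances.__getitem__)

# Feature names from least to most important
sorted_names = [names[i] for i in sorted_idx]
```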

Answer 1: (score: 0)

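I found the solution, and it is quite obvious... axs holds all the axes information as a list, so each subplot's axes can be accessed through it. The axes of the first subplot is therefore axs[0], and its labels can be read with:

labels = [item.get_text() for item in axs[0].get_xticklabels()]

However, this did not work in my case: the labels always came back empty, even though the values are shown in the plot. I used the axis limits instead to create new, transformed labels.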
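Building new labels from the axis limits might look like the sketch below. It is not the code from this answer; the helper name, the tick count, and the mean and sd values are assumptions for illustration. It computes evenly spaced tick positions on the standardized axis together with label strings in the original units.

```python
def transformed_ticks(xmin, xmax, mean, sd, n_ticks=5, fmt="{:.1f}"):
    """Evenly spaced tick positions in standardized units, with labels
    converted back to the original scale via x = z * sd + mean."""
    step = (xmax - xmin) / (n_ticks - 1)
    positions = [xmin + i * step for i in range(n_ticks)]
    labels = [fmt.format(p * sd + mean) for p in positions]
    return positions, labels

# Hypothetical limits, e.g. as returned by axs[0].get_xlim()
positions, labels = transformed_ticks(-2.0, 2.0, mean=10.0, sd=3.0)
```

With matplotlib, the result would be applied with axs[0].set_xticks(positions) followed by axs[0].set_xticklabels(labels), using each feature's own mean and standard deviation for its subplot.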