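Using scikit-learn to determine feature importances per class in a RF model

Asked: 2018-05-06 16:20:44

Tags: python machine-learning scikit-learn random-forest binary-data

I have a dataset that follows a one-hot encoding pattern, and my dependent variable is also binary. The first part of my code lists the important variables for the whole dataset. I used the method described in this Stack Overflow post, "Using scikit to determine contributions of each feature to a specific class prediction". I am not sure what output I am getting. In my case, the feature importance ranks "Delay Related DMS With Advice" as the most important feature of the overall model. I interpret this to mean that the variable should be important in either Class 0 or Class 1, but in the output I get, it is not important in either class. The code in the Stack Overflow post I shared above also shows that when the DV is binary, the output for Class 0 is exactly the opposite of Class 1 (in terms of sign, +/-). In my case, the values differ between the two classes.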

Here is what the plots look like:

[Plot: Feature Importance - Overall Model]

[Plot: Feature Importance - Class 0]

[Plot: Feature Importance - Class 1]

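The second part of my code shows the cumulative feature importances, but looking at the plot, it appears that no variable is important. Is my formula wrong, is my interpretation wrong, or both?

[Plot: cumulative feature importances]

Here is my code: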

import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import scale
from sklearn.ensemble import ExtraTreesClassifier


##get_ipython().run_line_magic('matplotlib', 'inline')

file = r'RCM_Binary.csv'
data = pd.read_csv(file)
print("data loaded successfully ...")

# Define features and target
X = data.iloc[:,:-1]
y = data.iloc[:,-1]

#split to training and testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=41)

# define classifier and fitting data
forest = ExtraTreesClassifier(random_state=1)
forest.fit(X_train, y_train)

# predict and get confusion matrix
y_pred = forest.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
print(cm)

#Applying 10-fold cross validation
accuracies = cross_val_score(estimator=forest, X=X_train, y=y_train, cv=10)
print("accuracy (10-fold): ", np.mean(accuracies))

# Features importances
importances = forest.feature_importances_
std = np.std([tree.feature_importances_ for tree in forest.estimators_],
             axis=0)
indices = np.argsort(importances)[::-1]
feature_list = [X.columns[indices[f]] for f in range(X.shape[1])]  #names of features.
ff = np.array(feature_list)

# Print the feature ranking
print("Feature ranking:")

for f in range(X.shape[1]):
    print("%d. feature %d (%f) name: %s" % (f + 1, indices[f], importances[indices[f]], ff[indices[f]]))


# Plot the feature importances of the forest
plt.figure()
plt.rcParams['figure.figsize'] = [16, 6]
plt.title("Feature importances")
plt.bar(range(X.shape[1]), importances[indices],
       color="r", yerr=std[indices], align="center")
plt.xticks(range(X.shape[1]), ff[indices], rotation=90)
plt.xlim([-1, X.shape[1]])
plt.show()


## The new additions to get feature importance to classes: 

# To get the importance according to each class:
def class_feature_importance(X, Y, feature_importances):
    N, M = X.shape
    X = scale(X)

    out = {}
    for c in set(Y):
        out[c] = dict(
            zip(range(N), np.mean(X[Y==c, :], axis=0)*feature_importances)
        )

    return out

result = class_feature_importance(X, y, importances)
print (json.dumps(result,indent=4))

# Plot the feature importances of the forest

titles = ["Did not Divert", "Diverted"]
for t, i in zip(titles, range(len(result))):
    plt.figure()
    plt.rcParams['figure.figsize'] = [16, 6]
    plt.title(t)
    plt.bar(range(len(result[i])), result[i].values(),
           color="r", align="center")
    plt.xticks(range(len(result[i])), ff[list(result[i].keys())], rotation=90)
    plt.xlim([-1, len(result[i])])
    plt.show()

Second part of the code

# List of tuples with variable and importance
feature_importances = [(feature, round(importance, 2)) for feature, importance in zip(feature_list, importances)]
# Sort the feature importances by most important first
feature_importances = sorted(feature_importances, key = lambda x: x[1], reverse = True)
# Print out the feature and importances 
[print('Variable: {:20} Importance: {}'.format(*pair)) for pair in feature_importances]

# list of x locations for plotting
x_values = list(range(len(importances)))
# Make a bar chart
plt.bar(x_values, importances, orientation = 'vertical', color = 'r', edgecolor = 'k', linewidth = 1.2)
# Tick labels for x axis
plt.xticks(x_values, feature_list, rotation='vertical')
# Axis labels and title
plt.ylabel('Importance'); plt.xlabel('Variable'); plt.title('Variable Importances');


# List of features sorted from most to least important
sorted_importances = [importance[1] for importance in feature_importances]
sorted_features = [importance[0] for importance in feature_importances]
# Cumulative importances
cumulative_importances = np.cumsum(sorted_importances)
# Make a line graph
plt.plot(x_values, cumulative_importances, 'g-')
# Draw line at 95% of importance retained
plt.hlines(y = 0.95, xmin=0, xmax=len(sorted_importances), color = 'r', linestyles = 'dashed')
# Format x ticks and labels
plt.xticks(x_values, sorted_features, rotation = 'vertical')
# Axis labels and title
plt.xlabel('Variable'); plt.ylabel('Cumulative Importance'); plt.title('Cumulative Importances');
plt.show()
# Find number of features for cumulative importance of 95%
# Add 1 because Python is zero-indexed
print('Number of features for 95% importance:', np.where(cumulative_importances > 0.95)[0][0] + 1)
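
A possible reason the cumulative plot looks flat: in the second part, `feature_list` is already sorted by importance while `importances` is still in column order, so `zip(feature_list, importances)` may pair names with the wrong values. Rounding each importance to two decimals can also turn many small one-hot importances into 0.0 before they are summed. The following is a minimal sketch, not the asker's code, that keeps names and values aligned and sums the unrounded values. The helper name `cumulative_importance` and the sample numbers are hypothetical.

```python
import numpy as np

def cumulative_importance(names, importances, threshold=0.95):
    """Sort features by importance and count how many are needed to reach the threshold."""
    names = np.asarray(names)
    importances = np.asarray(importances, dtype=float)
    order = np.argsort(importances)[::-1]        # most important first
    sorted_names = names[order]                  # names and values are reordered together
    sorted_importances = importances[order]
    cumulative = np.cumsum(sorted_importances)   # no rounding before summing
    # First position where the cumulative sum is >= threshold, as a 1-based count
    n_needed = int(np.searchsorted(cumulative, threshold)) + 1
    n_needed = min(n_needed, len(cumulative))    # guard if the threshold is never reached
    return sorted_names, cumulative, n_needed

# Hypothetical values for illustration only
names = ["f0", "f1", "f2", "f3"]
importances = [0.125, 0.5, 0.0625, 0.3125]
sorted_names, cumulative, n_needed = cumulative_importance(names, importances)
print(list(sorted_names), n_needed)
```

With a fitted model, `names` would be `X.columns` and `importances` would be `forest.feature_importances_`, both in the original column order.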

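2 Answers:

Answer 0 (score: 0)

This question may be outdated, but just in case anyone is interested:

The class_feature_importance function copied from the source treats rows as features and columns as samples, whereas most people do it the other way around. As a result, the per-class feature importance calculation goes wrong. Changing the code to

zip(range(M)

should fix it.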
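
As a rough sketch of what that fix could look like, the version below builds one entry per feature (M columns) and keys it by feature name instead of by index. The `feature_names` parameter and the demo data are assumptions added for illustration; they are not part of the original function.

```python
import numpy as np
from sklearn.preprocessing import scale

def class_feature_importance(X, y, feature_importances, feature_names):
    """Mean standardized feature value per class, weighted by the model importances."""
    X = scale(np.asarray(X, dtype=float))    # zero mean, unit variance per column
    y = np.asarray(y)                         # plain array so the boolean mask is positional
    out = {}
    for c in np.unique(y):
        weighted = X[y == c].mean(axis=0) * feature_importances
        # One entry per feature, keyed by name rather than by row index
        out[c.item()] = dict(zip(feature_names, weighted.tolist()))
    return out

# Hypothetical one-hot style data with equal class sizes, for illustration only
X_demo = np.array([[1, 0], [1, 1], [0, 1], [0, 0], [1, 0], [0, 1]])
y_demo = np.array([1, 1, 0, 0, 1, 0])
importances_demo = np.array([0.7, 0.3])
result_demo = class_feature_importance(X_demo, y_demo, importances_demo, ["a", "b"])
print(result_demo)
```

Because each column is standardized over the whole dataset, the class means weighted by class size sum to zero. With two classes the values therefore have opposite signs, but they are exact mirror images only when both classes have the same number of samples, which may explain why the two class plots differ in magnitude.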

Answer 1 (score: 0)

Also make sure your y variable is not an array. If it is an array, use […]