我的分类变量有很多不同的层次,我想知道我的模型指向哪个值:
indexer = StringIndexer(inputCols=["a","b"], outputCols=["a_ind", "b_ind"])
df = indexer.setHandleInvalid("keep").fit(df).transform(df)
encoder = OneHotEncoder(inputCols=["a_ind", "b_ind"],
outputCols=["a_vec", "b_vec"])
df = encoder.fit(df).transform(df)
df = VectorAssembler(inputCols=["a_vec","b_vec"],
outputCol="features").transform(df)
功能列如下:
0: 0
1: 45568
2:
0: 1
1: 2923
3:
0: 1
1: 1
当我打印决策树模型时,我看到下面的树
print(dtModel.toDebugString)
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_e9f4ea2ba51e, depth=5, numNodes=13, numClasses=2, numFeatures=45568
If (feature 0 in {1.0})
Predict: 0.0
Else (feature 0 not in {1.0})
If (feature 1 in {1.0})
Predict: 0.0
Else (feature 1 not in {1.0})
If (feature 3 in {1.0})
Predict: 0.0
Else (feature 3 not in {1.0})
If (feature 5 in {1.0})
If (feature 3664 in {1.0})
Predict: 1.0
Else (feature 3664 not in {1.0})
Predict: 0.0
Else (feature 5 not in {1.0})
If (feature 2933 in {1.0})
Predict: 1.0
Else (feature 2933 not in {1.0})
Predict: 0.0
我希望能够使用这棵树来知道哪些值正在触发1.0的预测(功能3664,功能2993,功能5,功能3等)。如何才能返回到a或b列中的原始字符串?