如何从pyspark管道(stringindexer-> onehotencoder-> vectorassembler)中获取原始分类变量值?

时间:2020-11-10 18:45:48

标签: python apache-spark machine-learning pyspark apache-spark-sql

我的分类变量有很多不同的层次,我想知道我的模型指向哪个值:

indexer = StringIndexer(inputCols=["a","b"], outputCols=["a_ind", "b_ind"])
df = indexer.setHandleInvalid("keep").fit(df).transform(df)

encoder = OneHotEncoder(inputCols=["a_ind", "b_ind"],
                        outputCols=["a_vec", "b_vec"])
df = encoder.fit(df).transform(df)

df = VectorAssembler(inputCols=["a_vec","b_vec"],
                     outputCol="features").transform(df)

功能列如下:

0: 0
1: 45568
2: 
0: 1
1: 2923
3: 
0: 1
1: 1

当我打印决策树模型时,我看到下面的树

print(dtModel.toDebugString)


DecisionTreeClassificationModel: uid=DecisionTreeClassifier_e9f4ea2ba51e, depth=5, numNodes=13, numClasses=2, numFeatures=45568
  If (feature 0 in {1.0})
   Predict: 0.0
  Else (feature 0 not in {1.0})
   If (feature 1 in {1.0})
    Predict: 0.0
   Else (feature 1 not in {1.0})
    If (feature 3 in {1.0})
     Predict: 0.0
    Else (feature 3 not in {1.0})
     If (feature 5 in {1.0})
      If (feature 3664 in {1.0})
       Predict: 1.0
      Else (feature 3664 not in {1.0})
       Predict: 0.0
     Else (feature 5 not in {1.0})
      If (feature 2933 in {1.0})
       Predict: 1.0
      Else (feature 2933 not in {1.0})
       Predict: 0.0

我希望能够使用这棵树来知道哪些值正在触发1.0的预测(功能3664,功能2993,功能5,功能3等)。如何才能返回到a或b列中的原始字符串?

0 个答案:

没有答案