How do I preprocess training data, use it to train a multi-label decision tree, and then convert the resulting tree to PMML?

Posted: 2020-07-02 13:48:57

Tags: python-3.x scikit-learn decision-tree multilabel-classification pmml

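I am part of a team building an application that accompanies stroke patients through their rehabilitation. One of its components is an algorithm that suggests treatments based on clinical data. I have clinical data from real and computer-simulated patients, for which we recorded the following inputs: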

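  • Motor impairment (Fugl-Meyer Upper Extremity, FMUE) scores for 33 exercises, each of which can be 0 (complete impairment), 1 (partial impairment) or 2 (no impairment): FMUE_1, ..., FMUE_33
  • Spasticity (Modified Ashworth Scale, MAS) for finger flexion, finger extension, wrist flexion, wrist extension, elbow flexion, elbow extension, shoulder anteversion, shoulder retroversion, shoulder abduction, shoulder adduction, shoulder internal rotation and shoulder external rotation. Each score is chosen from the list 0 (complete muscle flexibility), 1, 1.5, 2, 3, 4 (complete muscle rigidity)
  • Depression (Hospital Anxiety and Depression Scale, depression subscale, HADS-D), scored over 7 questions, each answered with 0 (completely agree), 1, 2 or 3 (completely disagree): HADS_D_1, ..., HADS_D_7
  • Working memory (Corsi score): from 0 (no memory) to 9 (very good memory)
  • Shoulder pain (Visual Analogue Scale, VAS): from 0 (no pain), ..., to 10 (worst imaginable pain)
  • Lifetime prevalence of epilepsy: 0 (no seizures), 1 (at least one seizure)
  • Willingness to consider experimental treatment: 0 (unwilling), 1 (willing)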
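Because every input has a small, fixed set of admissible values, it can be worth rejecting malformed rows before they reach the model. A minimal sketch of such a check, assuming pandas and the raw DataField names used in the PMML file below; the allowed-value sets are transcribed from the list above, and treating CORSI and VAS as integer scores is an assumption. The function name `invalid_entries` is hypothetical.

```python
import pandas as pd

# Spasticity columns, named as in the PMML DataDictionary below
# (the two shoulder-rotation fields there have no _MAS suffix).
MAS_COLS = [
    "FINGER_FLEXION_MAS", "FINGER_EXTENSION_MAS", "WRIST_FLEXION_MAS",
    "WRIST_EXTENSION_MAS", "ELBOW_FLEXION_MAS", "ELBOW_EXTENSION_MAS",
    "SHOULDER_ANTEVERSION_MAS", "SHOULDER_RETROVERSION_MAS",
    "SHOULDER_ABDUCTION_MAS", "SHOULDER_ADDUCTION_MAS",
    "SHOULDER_INNER_ROTATION", "SHOULDER_OUTER_ROTATION",
]

# Allowed values per raw input, transcribed from the list of inputs.
ALLOWED = {}
ALLOWED.update({f"FMUE_{i}": {0, 1, 2} for i in range(1, 34)})
ALLOWED.update({c: {0, 1, 1.5, 2, 3, 4} for c in MAS_COLS})
ALLOWED.update({f"HADS_D_{i}": {0, 1, 2, 3} for i in range(1, 8)})
ALLOWED["CORSI"] = set(range(0, 10))  # assumed integer 0..9
ALLOWED["VAS"] = set(range(0, 11))    # assumed integer 0..10
ALLOWED["EPILEPSY"] = {0, 1}
ALLOWED["EXPERIMENTAL"] = {0, 1}


def invalid_entries(df: pd.DataFrame) -> dict:
    """Map each problematic column to its out-of-range values (None if the column is missing)."""
    problems = {}
    for col, allowed in ALLOWED.items():
        if col not in df.columns:
            problems[col] = None
            continue
        values = df[col].dropna()  # missing values are left to the model's replacement rules
        bad = values[~values.isin(allowed)]
        if not bad.empty:
            problems[col] = sorted(bad.unique().tolist())
    return problems
```

An empty dictionary means the row is admissible; for example, an FMUE score of 3 would be reported as `{"FMUE_1": [3.0]}`.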

I have planned some feature engineering, as follows.

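  1. There will be a "proximity index" that judges whether the motor impairment is mostly proximal (shoulder, elbow), mostly distal (wrist, fingers), or both proximal and distal in roughly equal measure. FMUE scores 1 to 18 measure proximal impairment and FMUE scores 19 to 33 measure distal impairment, so the proximity index is the mean of (FMUE_19, ..., FMUE_33) minus the mean of (FMUE_1, ..., FMUE_18). Since a higher FMUE score means less impairment, a positive index means the proximal segments are more impaired than the distal ones. From the proximity index we define two booleans: PROXIMAL, which is 1 if the proximity index is greater than 0.2 and 0 otherwise; and DISTAL, which is 1 if the proximity index is less than -0.2 and 0 otherwise. Note that both booleans can be 0 (if the proximity index lies between -0.2 and 0.2), but they can never both be 1.
  2. We assess overall motor impairment by defining FMUE as the sum of (FMUE_1, ..., FMUE_33).
  3. We assess overall spasticity by defining MAS as the maximum of (finger flexion, ..., shoulder external rotation).
  4. We assess overall depression by defining HADS_D as the sum of (HADS_D_1, ..., HADS_D_7).
  5. The features CORSI, VAS, EXPERIMENTAL and EPILEPSY are kept unchanged.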
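The five steps above can be sketched as one pandas function wrapped in scikit-learn's `FunctionTransformer`, so that it can sit in front of the classifier in a `Pipeline`. This is a minimal sketch under stated assumptions, not the pipeline from the question: the column names are assumed to match the PMML DataFields, and `engineer_features` is a hypothetical name.

```python
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

PROXIMAL_COLS = [f"FMUE_{i}" for i in range(1, 19)]  # FMUE_1..FMUE_18
DISTAL_COLS = [f"FMUE_{i}" for i in range(19, 34)]   # FMUE_19..FMUE_33
FMUE_COLS = PROXIMAL_COLS + DISTAL_COLS
MAS_COLS = [
    "FINGER_FLEXION_MAS", "FINGER_EXTENSION_MAS", "WRIST_FLEXION_MAS",
    "WRIST_EXTENSION_MAS", "ELBOW_FLEXION_MAS", "ELBOW_EXTENSION_MAS",
    "SHOULDER_ANTEVERSION_MAS", "SHOULDER_RETROVERSION_MAS",
    "SHOULDER_ABDUCTION_MAS", "SHOULDER_ADDUCTION_MAS",
    "SHOULDER_INNER_ROTATION", "SHOULDER_OUTER_ROTATION",
]
HADS_COLS = [f"HADS_D_{i}" for i in range(1, 8)]
PASSTHROUGH = ["CORSI", "VAS", "EXPERIMENTAL", "EPILEPSY"]


def engineer_features(df: pd.DataFrame, cutoff: float = 0.2) -> pd.DataFrame:
    """Turn the raw inputs into the engineered features described above."""
    # Proximity index: mean distal score minus mean proximal score.
    # Higher FMUE = less impairment, so a positive index = mostly proximal impairment.
    index = df[DISTAL_COLS].mean(axis=1) - df[PROXIMAL_COLS].mean(axis=1)
    out = pd.DataFrame(index=df.index)
    out["PROXIMAL"] = (index > cutoff).astype(int)
    out["DISTAL"] = (index < -cutoff).astype(int)
    out["FMUE"] = df[FMUE_COLS].sum(axis=1)    # overall motor impairment
    out["MAS"] = df[MAS_COLS].max(axis=1)      # worst spasticity over all joints
    out["HADS_D"] = df[HADS_COLS].sum(axis=1)  # overall depression score
    for col in PASSTHROUGH:                    # unchanged features
        out[col] = df[col]
    return out


# Stateless step; FunctionTransformer passes the DataFrame straight to the function.
feature_step = FunctionTransformer(engineer_features)
```

In a pipeline it would precede the tree, e.g. `Pipeline([("features", feature_step), ("tree", DecisionTreeClassifier())])`.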

The training labels for each patient are a subset of the following set of treatments:

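  • TIMES (an experimental treatment)
  • ECOSS (another experimental treatment)
  • AVANCER (yet another treatment)
  • CIMT (constraint-induced movement therapy)
  • PROXIMAL_ES (electrical stimulation of the proximal muscles)
  • DISTAL_ES (electrical stimulation of the distal muscles)
  • MIRROR (mirror therapy, which uses a mirror to trick the brain into perceiving the impaired limb as working as well as the healthy one)
  • PSYCHOTHERAPY (to treat depression)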
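Since each patient's label is a set of treatments, scikit-learn expects the targets as a binary indicator matrix with one column per treatment. A minimal sketch using `MultiLabelBinarizer`, with the column order taken from the `class` Value list in the PMML file below; the three label sets are made up for illustration.

```python
from sklearn.preprocessing import MultiLabelBinarizer

TREATMENTS = ["TIMES", "ECOSS", "AVANCER", "CIMT",
              "PROXIMAL_ES", "DISTAL_ES", "MIRROR", "PSYCHOTHERAPY"]

# Fixing `classes` keeps the column order stable even if a treatment
# never occurs in a given batch of patients.
mlb = MultiLabelBinarizer(classes=TREATMENTS)

# Hypothetical label sets for three patients.
Y = mlb.fit_transform([
    {"CIMT", "DISTAL_ES"},
    {"TIMES", "ECOSS", "PSYCHOTHERAPY"},
    set(),  # no treatment recommended
])

print(list(mlb.classes_))
print(Y)  # one row per patient, one 0/1 column per treatment
```

`mlb.inverse_transform(Y)` maps predicted indicator rows back to tuples of treatment names.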

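To maximize interpretability, I think a multi-label decision tree is the way to go. Using scikit-learn, I built a feature-engineering pipeline followed by a classifier and trained it successfully. The problem is that I also need to export the model to PMML to hand it over to the production team. I tried sklearn2pmml, but it failed, so I decided to write the PMML file from scratch. Unfortunately, that also turned out to be harder than I expected. Here is what I have so far: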

<PMML version="4.4">
  <Header/>
  <DataDictionary>
    <!-- non-preprocessed fields -->
    <DataField name="FMUE_1" optype="continuous" dataType="double"/>
    <DataField name="FMUE_2" optype="continuous" dataType="double"/>
    <DataField name="FMUE_3" optype="continuous" dataType="double"/>
    <DataField name="FMUE_4" optype="continuous" dataType="double"/>
    <DataField name="FMUE_5" optype="continuous" dataType="double"/>
    <DataField name="FMUE_6" optype="continuous" dataType="double"/>
    <DataField name="FMUE_7" optype="continuous" dataType="double"/>
    <DataField name="FMUE_8" optype="continuous" dataType="double"/>
    <DataField name="FMUE_9" optype="continuous" dataType="double"/>
    <DataField name="FMUE_10" optype="continuous" dataType="double"/>
    <DataField name="FMUE_11" optype="continuous" dataType="double"/>
    <DataField name="FMUE_12" optype="continuous" dataType="double"/>
    <DataField name="FMUE_13" optype="continuous" dataType="double"/>
    <DataField name="FMUE_14" optype="continuous" dataType="double"/>
    <DataField name="FMUE_15" optype="continuous" dataType="double"/>
    <DataField name="FMUE_16" optype="continuous" dataType="double"/>
    <DataField name="FMUE_17" optype="continuous" dataType="double"/>
    <DataField name="FMUE_18" optype="continuous" dataType="double"/>
    <DataField name="FMUE_19" optype="continuous" dataType="double"/>
    <DataField name="FMUE_20" optype="continuous" dataType="double"/>
    <DataField name="FMUE_21" optype="continuous" dataType="double"/>
    <DataField name="FMUE_22" optype="continuous" dataType="double"/>
    <DataField name="FMUE_23" optype="continuous" dataType="double"/>
    <DataField name="FMUE_24" optype="continuous" dataType="double"/>
    <DataField name="FMUE_25" optype="continuous" dataType="double"/>
    <DataField name="FMUE_26" optype="continuous" dataType="double"/>
    <DataField name="FMUE_27" optype="continuous" dataType="double"/>
    <DataField name="FMUE_28" optype="continuous" dataType="double"/>
    <DataField name="FMUE_29" optype="continuous" dataType="double"/>
    <DataField name="FMUE_30" optype="continuous" dataType="double"/>
    <DataField name="FMUE_31" optype="continuous" dataType="double"/>
    <DataField name="FMUE_32" optype="continuous" dataType="double"/>
    <DataField name="FMUE_33" optype="continuous" dataType="double"/>
    <DataField name="FINGER_FLEXION_MAS" optype="continuous" dataType="double"/>
    <DataField name="FINGER_EXTENSION_MAS" optype="continuous" dataType="double"/>
    <DataField name="WRIST_FLEXION_MAS" optype="continuous" dataType="double"/>
    <DataField name="WRIST_EXTENSION_MAS" optype="continuous" dataType="double"/>
    <DataField name="ELBOW_FLEXION_MAS" optype="continuous" dataType="double"/>
    <DataField name="ELBOW_EXTENSION_MAS" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_ANTEVERSION_MAS" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_RETROVERSION_MAS" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_ABDUCTION_MAS" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_ADDUCTION_MAS" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_INNER_ROTATION" optype="continuous" dataType="double"/>
    <DataField name="SHOULDER_OUTER_ROTATION" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_1" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_2" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_3" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_4" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_5" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_6" optype="continuous" dataType="double"/>
    <DataField name="HADS_D_7" optype="continuous" dataType="double"/>
    <DataField name="CORSI" optype="continuous" dataType="double"/>
    <DataField name="VAS" optype="continuous" dataType="double"/>
    <DataField name="EPILEPSY" optype="continuous" dataType="double"/>
    <DataField name="EXPERIMENTAL" optype="continuous" dataType="double"/>
    <!-- preprocessed fields -->
    <DataField name="PROXIMAL" optype="continuous" dataType="double"/>
    <DataField name="DISTAL" optype="continuous" dataType="double"/>
    <DataField name="FMUE" optype="continuous" dataType="double"/>
    <DataField name="MAS" optype="continuous" dataType="double"/>
    <DataField name="HADS_D" optype="continuous" dataType="double"/>
    <!-- labels -->
    <DataField name="class" optype="categorical" dataType="string">
      <Value value="TIMES"/>
      <Value value="ECOSS"/>
      <Value value="AVANCER"/>
      <Value value="CIMT"/>
      <Value value="PROXIMAL_ES"/>
      <Value value="DISTAL_ES"/>
      <Value value="MIRROR"/>
      <Value value="PSYCHOTHERAPY"/>
    </DataField>
  </DataDictionary>
  <TransformationDictionary>
    <DerivedField name="PROXIMAL" optype="continuous" dataType="double">
      <Apply function="threshold">
        <Constant dataType="double">-0.2</Constant>
        <Apply function="-">
          <Apply function="avg">
            <FieldRef field="FMUE_1"/>
            <FieldRef field="FMUE_2"/>
            <FieldRef field="FMUE_3"/>
            <FieldRef field="FMUE_4"/>
            <FieldRef field="FMUE_5"/>
            <FieldRef field="FMUE_6"/>
            <FieldRef field="FMUE_7"/>
            <FieldRef field="FMUE_8"/>
            <FieldRef field="FMUE_9"/>
            <FieldRef field="FMUE_10"/>
            <FieldRef field="FMUE_11"/>
            <FieldRef field="FMUE_12"/>
            <FieldRef field="FMUE_13"/>
            <FieldRef field="FMUE_14"/>
            <FieldRef field="FMUE_15"/>
            <FieldRef field="FMUE_16"/>
            <FieldRef field="FMUE_17"/>
            <FieldRef field="FMUE_18"/>
          </Apply>
          <Apply function="avg">
            <FieldRef field="FMUE_19"/>
            <FieldRef field="FMUE_20"/>
            <FieldRef field="FMUE_21"/>
            <FieldRef field="FMUE_22"/>
            <FieldRef field="FMUE_23"/>
            <FieldRef field="FMUE_24"/>
            <FieldRef field="FMUE_25"/>
            <FieldRef field="FMUE_26"/>
            <FieldRef field="FMUE_27"/>
            <FieldRef field="FMUE_28"/>
            <FieldRef field="FMUE_29"/>
            <FieldRef field="FMUE_30"/>
            <FieldRef field="FMUE_31"/>
            <FieldRef field="FMUE_32"/>
            <FieldRef field="FMUE_33"/>
          </Apply>
        </Apply>
      </Apply>
    </DerivedField>
    <DerivedField name="DISTAL" optype="continuous" dataType="double">
      <Apply function="threshold">
        <Apply function="-">
          <Apply function="avg">
            <FieldRef field="FMUE_1"/>
            <FieldRef field="FMUE_2"/>
            <FieldRef field="FMUE_3"/>
            <FieldRef field="FMUE_4"/>
            <FieldRef field="FMUE_5"/>
            <FieldRef field="FMUE_6"/>
            <FieldRef field="FMUE_7"/>
            <FieldRef field="FMUE_8"/>
            <FieldRef field="FMUE_9"/>
            <FieldRef field="FMUE_10"/>
            <FieldRef field="FMUE_11"/>
            <FieldRef field="FMUE_12"/>
            <FieldRef field="FMUE_13"/>
            <FieldRef field="FMUE_14"/>
            <FieldRef field="FMUE_15"/>
            <FieldRef field="FMUE_16"/>
            <FieldRef field="FMUE_17"/>
            <FieldRef field="FMUE_18"/>
          </Apply>
          <Apply function="avg">
            <FieldRef field="FMUE_19"/>
            <FieldRef field="FMUE_20"/>
            <FieldRef field="FMUE_21"/>
            <FieldRef field="FMUE_22"/>
            <FieldRef field="FMUE_23"/>
            <FieldRef field="FMUE_24"/>
            <FieldRef field="FMUE_25"/>
            <FieldRef field="FMUE_26"/>
            <FieldRef field="FMUE_27"/>
            <FieldRef field="FMUE_28"/>
            <FieldRef field="FMUE_29"/>
            <FieldRef field="FMUE_30"/>
            <FieldRef field="FMUE_31"/>
            <FieldRef field="FMUE_32"/>
            <FieldRef field="FMUE_33"/>
          </Apply>
        </Apply>
        <Constant dataType="double">0.2</Constant>
      </Apply>
    </DerivedField>
    <DerivedField name="FMUE" optype="continuous" dataType="double">
      <Apply function="sum">
        <FieldRef field="FMUE_1"/>
        <FieldRef field="FMUE_2"/>
        <FieldRef field="FMUE_3"/>
        <FieldRef field="FMUE_4"/>
        <FieldRef field="FMUE_5"/>
        <FieldRef field="FMUE_6"/>
        <FieldRef field="FMUE_7"/>
        <FieldRef field="FMUE_8"/>
        <FieldRef field="FMUE_9"/>
        <FieldRef field="FMUE_10"/>
        <FieldRef field="FMUE_11"/>
        <FieldRef field="FMUE_12"/>
        <FieldRef field="FMUE_13"/>
        <FieldRef field="FMUE_14"/>
        <FieldRef field="FMUE_15"/>
        <FieldRef field="FMUE_16"/>
        <FieldRef field="FMUE_17"/>
        <FieldRef field="FMUE_18"/>
        <FieldRef field="FMUE_19"/>
        <FieldRef field="FMUE_20"/>
        <FieldRef field="FMUE_21"/>
        <FieldRef field="FMUE_22"/>
        <FieldRef field="FMUE_23"/>
        <FieldRef field="FMUE_24"/>
        <FieldRef field="FMUE_25"/>
        <FieldRef field="FMUE_26"/>
        <FieldRef field="FMUE_27"/>
        <FieldRef field="FMUE_28"/>
        <FieldRef field="FMUE_29"/>
        <FieldRef field="FMUE_30"/>
        <FieldRef field="FMUE_31"/>
        <FieldRef field="FMUE_32"/>
        <FieldRef field="FMUE_33"/>
      </Apply>
    </DerivedField>
    <DerivedField name="MAS" optype="continuous" dataType="double">
      <Apply function="max">
        <FieldRef field="FINGER_FLEXION_MAS"/>
        <FieldRef field="FINGER_EXTENSION_MAS"/>
        <FieldRef field="WRIST_FLEXION_MAS"/>
        <FieldRef field="WRIST_EXTENSION_MAS"/>
        <FieldRef field="ELBOW_FLEXION_MAS"/>
        <FieldRef field="ELBOW_EXTENSION_MAS"/>
        <FieldRef field="SHOULDER_ANTEVERSION_MAS"/>
        <FieldRef field="SHOULDER_RETROVERSION_MAS"/>
        <FieldRef field="SHOULDER_ABDUCTION_MAS"/>
        <FieldRef field="SHOULDER_ADDUCTION_MAS"/>
        <FieldRef field="SHOULDER_INNER_ROTATION"/>
        <FieldRef field="SHOULDER_OUTER_ROTATION"/>
      </Apply>
    </DerivedField>
    <DerivedField name="HADS_D" optype="continuous" dataType="double">
      <Apply function="sum">
        <FieldRef field="HADS_D_1"/>
        <FieldRef field="HADS_D_2"/>
        <FieldRef field="HADS_D_3"/>
        <FieldRef field="HADS_D_4"/>
        <FieldRef field="HADS_D_5"/>
        <FieldRef field="HADS_D_6"/>
        <FieldRef field="HADS_D_7"/>
      </Apply>
    </DerivedField>
  </TransformationDictionary>
  <TreeModel functionName="classification">
    <MiningSchema>
      <MiningField name="PROXIMAL" missingValueReplacement="0"/>
      <MiningField name="DISTAL" missingValueReplacement="0"/>
      <MiningField name="FMUE" missingValueReplacement="22"/>
      <MiningField name="MAS" missingValueReplacement="0"/>
      <MiningField name="HADS_D" missingValueReplacement="0"/>
      <MiningField name="CORSI" missingValueReplacement="3"/>
      <MiningField name="VAS" missingValueReplacement="0"/>
      <MiningField name="EPILEPSY" missingValueReplacement="1"/>
      <MiningField name="EXPERIMENTAL" missingValueReplacement="0"/>
      <MiningField name="class" usageType="predicted"/>
    </MiningSchema>
    <Node id="0" recordCount="2000">
      <True/>
      <ScoreDistribution value="TIMES" recordCount="166"/>
      <ScoreDistribution value="ECOSS" recordCount="369"/>
      <ScoreDistribution value="AVANCER" recordCount="465"/>
      <ScoreDistribution value="CIMT" recordCount="692"/>
      <ScoreDistribution value="PROXIMAL_ES" recordCount="286"/>
      <ScoreDistribution value="DISTAL_ES" recordCount="922"/>
      <ScoreDistribution value="MIRROR" recordCount="142"/>
      <ScoreDistribution value="PSYCHOTHERAPY" recordCount="28"/>
      <Node id="1" recordCount="1000">
        <SimplePredicate field="EXPERIMENTAL" operator="equal" value="1"/>
        <ScoreDistribution value="TIMES" recordCount="166"/>
        <ScoreDistribution value="ECOSS" recordCount="369"/>
        <ScoreDistribution value="AVANCER" recordCount="465"/>
        <ScoreDistribution value="CIMT" recordCount="346"/>
        <ScoreDistribution value="PROXIMAL_ES" recordCount="143"/>
        <ScoreDistribution value="DISTAL_ES" recordCount="461"/>
        <ScoreDistribution value="MIRROR" recordCount="71"/>
        <ScoreDistribution value="PSYCHOTHERAPY" recordCount="14"/>
      </Node>
      <Node id="2" recordCount="1000">
        <SimplePredicate field="EXPERIMENTAL" operator="equal" value="0"/>
        <ScoreDistribution value="TIMES" recordCount="0"/>
        <ScoreDistribution value="ECOSS" recordCount="0"/>
        <ScoreDistribution value="AVANCER" recordCount="0"/>
        <ScoreDistribution value="CIMT" recordCount="346"/>
        <ScoreDistribution value="PROXIMAL_ES" recordCount="143"/>
        <ScoreDistribution value="DISTAL_ES" recordCount="461"/>
        <ScoreDistribution value="MIRROR" recordCount="71"/>
        <ScoreDistribution value="PSYCHOTHERAPY" recordCount="14"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>
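The argument order in the two `threshold` calls is easy to get backwards. PMML's built-in `threshold(x, y)` returns 1 if x > y and 0 otherwise; the sketch below mirrors the PROXIMAL and DISTAL DerivedFields in plain Python and compares them with the definition in item 1 of the feature list (proximity index = mean distal score minus mean proximal score). It is only a consistency check of the formulas, not a PMML evaluator; the function names are hypothetical.

```python
def pmml_threshold(x: float, y: float) -> int:
    """PMML built-in threshold(x, y): 1 if x > y, else 0."""
    return 1 if x > y else 0


def pmml_flags(prox_mean: float, dist_mean: float) -> tuple:
    """PROXIMAL and DISTAL as the DerivedFields in MotorTree.pmml compute them."""
    diff = prox_mean - dist_mean           # Apply function="-" (proximal avg first)
    proximal = pmml_threshold(-0.2, diff)  # threshold(-0.2, diff)
    distal = pmml_threshold(diff, 0.2)     # threshold(diff, 0.2)
    return proximal, distal


def text_flags(prox_mean: float, dist_mean: float) -> tuple:
    """PROXIMAL and DISTAL as defined in the feature-engineering list."""
    index = dist_mean - prox_mean
    return int(index > 0.2), int(index < -0.2)


# Compare both definitions over a grid of mean FMUE scores in [0, 2].
grid = [i / 20 for i in range(41)]
mismatches = [(p, d) for p in grid for d in grid
              if pmml_flags(p, d) != text_flags(p, d)]
print(len(mismatches))
```

An empty mismatch list means the PMML transformations agree with the written definition at every grid point.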

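I wanted to test this by feeding in a patient with 1 everywhere and checking that the patient ends up in node 1, with a probability of 0.166 for TIMES, 0.369 for ECOSS, and so on. The probabilities need not add up to 1 (the treatments are not mutually exclusive).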
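For comparison, this per-label behaviour is what a multi-output tree gives in scikit-learn: `DecisionTreeClassifier` accepts a 2-D label matrix, and `predict_proba` then returns one `[P(0), P(1)]` array per label, so each label's positive probability is estimated independently and the values need not sum to 1. A minimal sketch on made-up data with a single binary feature standing in for something like EXPERIMENTAL and three hypothetical labels:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy data: 8 patients, 1 binary feature, 3 hypothetical treatment labels.
X = np.array([[0]] * 4 + [[1]] * 4)
Y = np.array([
    [0, 1, 0], [0, 1, 0], [0, 0, 0], [0, 0, 1],  # feature = 0
    [1, 1, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0],  # feature = 1
])

clf = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, Y)

# For multi-output targets predict_proba returns a list: one (n_samples, 2) array per label.
proba = clf.predict_proba([[1]])
p_positive = [float(p[0, 1]) for p in proba]  # column 1 = P(label == 1)
print(p_positive, sum(p_positive))
```

The leaf for feature = 1 holds four patients, so the three labels get 4/4, 3/4 and 2/4, which sum to 2.25 rather than 1.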

I saved the document as MotorTree.pmml and then ran the following in Python:

import pandas as pd
from pypmml import Model

clf = Model.fromFile("MotorTree.pmml")

# Column names must match the DataField names in MotorTree.pmml exactly
# (the two shoulder-rotation fields there have no _MAS suffix).
cols = ([f"FMUE_{i}" for i in range(1, 34)]
        + ["FINGER_FLEXION_MAS", "FINGER_EXTENSION_MAS", "WRIST_FLEXION_MAS",
           "WRIST_EXTENSION_MAS", "ELBOW_FLEXION_MAS", "ELBOW_EXTENSION_MAS",
           "SHOULDER_ANTEVERSION_MAS", "SHOULDER_RETROVERSION_MAS",
           "SHOULDER_ABDUCTION_MAS", "SHOULDER_ADDUCTION_MAS",
           "SHOULDER_INNER_ROTATION", "SHOULDER_OUTER_ROTATION"]
        + [f"HADS_D_{i}" for i in range(1, 8)]
        + ["VAS", "CORSI", "EXPERIMENTAL", "EPILEPSY"])

# Score a single test patient with every input set to 0.
clf.predict(pd.DataFrame(index=[0], columns=cols, data=0))

which returned the DataFrame

   node_id  predicted_class  probability  probability_AVANCER  probability_CIMT  probability_DISTAL_ES  probability_ECOSS  probability_MIRROR  probability_PROXIMAL_ES  probability_PSYCHOTHERAPY  probability_TIMES
0        2             None          NaN                  0.0            0.3343               0.445411                0.0            0.068599                 0.138164                   0.013527                0.0

It looks like the probabilities are automatically rescaled so that they add up to 1, which I don't want. How can I fix this?
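The returned row was scored in node 2 (EXPERIMENTAL = 0), and its numbers are consistent with each ScoreDistribution count of that node being divided by the sum of the node's counts (1035) rather than by its recordCount (1000). Because the treatments overlap, the counts add up to more than recordCount, and dividing by their sum is exactly the rescaling observed. A small check of the arithmetic; this reproduces the numbers and is not pypmml's code:

```python
# ScoreDistribution counts of node 2 in MotorTree.pmml.
node2_record_count = 1000
node2_counts = {
    "TIMES": 0, "ECOSS": 0, "AVANCER": 0, "CIMT": 346,
    "PROXIMAL_ES": 143, "DISTAL_ES": 461, "MIRROR": 71, "PSYCHOTHERAPY": 14,
}

total = sum(node2_counts.values())  # larger than recordCount because labels overlap

# What the returned DataFrame shows: counts normalised by their own sum.
normalised = {k: v / total for k, v in node2_counts.items()}
# What the question expects: counts divided by the node's recordCount.
per_label = {k: v / node2_record_count for k, v in node2_counts.items()}

print(total)
print(round(normalised["CIMT"], 6), round(per_label["CIMT"], 6))
```

Dividing by recordCount instead gives 0.346 for CIMT and 0.461 for DISTAL_ES, which is the behaviour described in the question.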

If you have read this far, thank you sincerely, and thank you even more warmly if you can provide an answer!

0 answers

没有答案
相关问题