如何从决策树模型中提取每个叶子的规则?

时间:2016-06-24 14:32:57

标签: java xslt data-mining pmml

我有pmml格式的决策树模型,如下所示。如何以文本或其他格式保存每个叶子的规则?

例如:uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli <= 3.5 => B

  <!--
    PMML TreeModel fragment used as the sample input for the stylesheets in this post.
    Structure visible in this fragment:
      * MiningSchema lists the input fields plus one target field (Class_Categorical).
      * Node elements nest to form the tree; every split shown here has two
        complementary branches (lessOrEqual / greaterThan on the same field and value).
      * A leaf is a Node with no child Node (ids 3, 10, 11, 15, 16, 17, 18).
      * Each Node carries a score and one ScoreDistribution per class value
        (B and M) with a recordCount.
    A rule for a leaf is the conjunction of the SimplePredicate elements found on the
    path from the root down to that leaf, followed by the score of the leaf.
  -->
  <TreeModel modelName="DecisionTree" functionName="classification" splitCharacteristic="binarySplit" missingValueStrategy="lastPrediction" noTrueChildStrategy="returnNullPrediction">
    <MiningSchema>
      <MiningField name="clumpthickness" invalidValueTreatment="asIs"/>
      <MiningField name="uniformitycellsize" invalidValueTreatment="asIs"/>
      <MiningField name="uniformitycellshape" invalidValueTreatment="asIs"/>
      <MiningField name="marginaladhesion" invalidValueTreatment="asIs"/>
      <MiningField name="epithelialcellsize" invalidValueTreatment="asIs"/>
      <MiningField name="barenuclei" invalidValueTreatment="asIs"/>
      <MiningField name="blandchromatin" invalidValueTreatment="asIs"/>
      <MiningField name="normalnucleoli" invalidValueTreatment="asIs"/>
      <MiningField name="mitoses" invalidValueTreatment="asIs"/>
      <MiningField name="partition" invalidValueTreatment="asIs"/>
      <MiningField name="Class_Categorical" invalidValueTreatment="asIs" usageType="target"/>
    </MiningSchema>
    <!-- Root node: its predicate is <True/>, not a SimplePredicate, so it adds no condition to any rule. -->
    <Node id="0" score="B" recordCount="559.0">
      <True/>
      <ScoreDistribution value="B" recordCount="365.0"/>
      <ScoreDistribution value="M" recordCount="194.0"/>
      <!-- Left subtree: uniformitycellsize <= 3.5 -->
      <Node id="1" score="B" recordCount="384.0">
        <SimplePredicate field="uniformitycellsize" operator="lessOrEqual" value="3.5"/>
        <ScoreDistribution value="B" recordCount="356.0"/>
        <ScoreDistribution value="M" recordCount="28.0"/>
        <Node id="2" score="B" recordCount="368.0">
          <SimplePredicate field="clumpthickness" operator="lessOrEqual" value="6.5"/>
          <ScoreDistribution value="B" recordCount="354.0"/>
          <ScoreDistribution value="M" recordCount="14.0"/>
          <Node id="3" score="B" recordCount="353.0">
            <SimplePredicate field="normalnucleoli" operator="lessOrEqual" value="3.5"/>
            <ScoreDistribution value="B" recordCount="347.0"/>
            <ScoreDistribution value="M" recordCount="6.0"/>
          </Node>
          <Node id="10" score="M" recordCount="15.0">
            <SimplePredicate field="normalnucleoli" operator="greaterThan" value="3.5"/>
            <ScoreDistribution value="B" recordCount="7.0"/>
            <ScoreDistribution value="M" recordCount="8.0"/>
          </Node>
        </Node>
        <Node id="11" score="M" recordCount="16.0">
          <SimplePredicate field="clumpthickness" operator="greaterThan" value="6.5"/>
          <ScoreDistribution value="B" recordCount="2.0"/>
          <ScoreDistribution value="M" recordCount="14.0"/>
        </Node>
      </Node>
      <!-- Right subtree: uniformitycellsize > 3.5 -->
      <Node id="12" score="M" recordCount="175.0">
        <SimplePredicate field="uniformitycellsize" operator="greaterThan" value="3.5"/>
        <ScoreDistribution value="B" recordCount="9.0"/>
        <ScoreDistribution value="M" recordCount="166.0"/>
        <Node id="13" score="M" recordCount="33.0">
          <SimplePredicate field="uniformitycellsize" operator="lessOrEqual" value="4.5"/>
          <ScoreDistribution value="B" recordCount="7.0"/>
          <ScoreDistribution value="M" recordCount="26.0"/>
          <Node id="14" score="M" recordCount="21.0">
            <SimplePredicate field="marginaladhesion" operator="lessOrEqual" value="5.5"/>
            <ScoreDistribution value="B" recordCount="7.0"/>
            <ScoreDistribution value="M" recordCount="14.0"/>
            <Node id="15" score="B" recordCount="10.0">
              <SimplePredicate field="clumpthickness" operator="lessOrEqual" value="7.5"/>
              <ScoreDistribution value="B" recordCount="6.0"/>
              <ScoreDistribution value="M" recordCount="4.0"/>
            </Node>
            <Node id="16" score="M" recordCount="11.0">
              <SimplePredicate field="clumpthickness" operator="greaterThan" value="7.5"/>
              <ScoreDistribution value="B" recordCount="1.0"/>
              <ScoreDistribution value="M" recordCount="10.0"/>
            </Node>
          </Node>
          <Node id="17" score="M" recordCount="12.0">
            <SimplePredicate field="marginaladhesion" operator="greaterThan" value="5.5"/>
            <ScoreDistribution value="B" recordCount="0.0"/>
            <ScoreDistribution value="M" recordCount="12.0"/>
          </Node>
        </Node>
        <Node id="18" score="M" recordCount="142.0">
          <SimplePredicate field="uniformitycellsize" operator="greaterThan" value="4.5"/>
          <ScoreDistribution value="B" recordCount="2.0"/>
          <ScoreDistribution value="M" recordCount="140.0"/>
        </Node>
      </Node>
    </Node>
  </TreeModel>

===========================================================================

用于实现此类结果的XSL样式表如下所示。

<!--
  Prints one text line per leaf of a decision tree, in the form
      field1 <= v1 ^ field2 > v2 => score
  A leaf is a Node element that has no child Node. The conditions are the
  SimplePredicate elements of the leaf and of all its ancestor Nodes, in
  document order (root first).

  The XPath expressions use unprefixed names, so they only match input elements
  that are in no namespace. If the input document declares a default namespace,
  a prefix must be bound to it in this stylesheet and used in every path step.
-->
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:output method="text" encoding="UTF-8"/>

<xsl:template match="/">
    <xsl:for-each select="//Node[not(Node)]">
        <xsl:for-each select="ancestor-or-self::Node/SimplePredicate">
            <xsl:value-of select="@field"/>
            <!-- Map the operator name to a symbol; unknown names are printed as-is
                 so that field and value are never fused together. -->
            <xsl:choose>
                <xsl:when test="@operator = 'equal'"> = </xsl:when>
                <xsl:when test="@operator = 'notEqual'"> != </xsl:when>
                <xsl:when test="@operator = 'lessThan'"> &lt; </xsl:when>
                <xsl:when test="@operator = 'lessOrEqual'"> &lt;= </xsl:when>
                <xsl:when test="@operator = 'greaterThan'"> &gt; </xsl:when>
                <xsl:when test="@operator = 'greaterOrEqual'"> &gt;= </xsl:when>
                <xsl:otherwise>
                    <xsl:value-of select="concat(' ', @operator)"/>
                    <!-- Only add the separating space when a value follows. -->
                    <xsl:if test="@value">
                        <xsl:text> </xsl:text>
                    </xsl:if>
                </xsl:otherwise>
            </xsl:choose>
            <xsl:value-of select="@value"/>
            <xsl:if test="position() != last()">
                <xsl:text> ^ </xsl:text>
            </xsl:if>
        </xsl:for-each>
        <!-- Emitted outside the predicate loop so that a leaf whose path contains
             no SimplePredicate still gets its score printed. The context node
             here is the leaf Node itself. -->
        <xsl:text> => </xsl:text>
        <xsl:value-of select="@score"/>
        <xsl:text>&#10;</xsl:text>
    </xsl:for-each>
</xsl:template>

</xsl:stylesheet>

结果是(注意:下面的示例输出来自另一个模型,其字段名和阈值与上面的PMML不同):

Uniformity of Cell Size <= 2.5 ^ Bare Nuclei <= 5.5 => B
Uniformity of Cell Size <= 2.5 ^ Bare Nuclei > 5.5 => M
Uniformity of Cell Size > 2.5 ^ Uniformity of Cell Shape <= 2.5 ^ Clump Thickness <= 5.5 => B
Uniformity of Cell Size > 2.5 ^ Uniformity of Cell Shape <= 2.5 ^ Clump Thickness > 5.5 => M
Uniformity of Cell Size > 2.5 ^ Uniformity of Cell Shape > 2.5 => M

2 个答案:

答案 0 :(得分:1)

在XSLT中,您可以执行以下操作:

XSLT 1.0

<!--
  Writes one text line per ScoreDistribution of every leaf Node (a Node with no
  child Node): the conditions on the path to the leaf joined by " ^ ", then
  " => ", the class value and its record count in parentheses.
-->
<xsl:stylesheet xmlns:xsl="http://www.w3.org/1999/XSL/Transform" version="1.0">
    <xsl:output encoding="UTF-8" method="text"/>

    <!-- Entry point: visit the class distributions of all leaves in document order. -->
    <xsl:template match="/">
        <xsl:apply-templates select="//Node[not(Node)]/ScoreDistribution"/>
    </xsl:template>

    <!-- One output line: path conditions, then the class value and its count. -->
    <xsl:template match="ScoreDistribution">
        <xsl:apply-templates select="ancestor::Node/SimplePredicate"/>
        <xsl:text> => </xsl:text>
        <xsl:value-of select="@value"/>
        <xsl:text> (</xsl:text>
        <xsl:value-of select="@recordCount"/>
        <xsl:text>) &#10;</xsl:text>
    </xsl:template>

    <!-- One condition; position()/last() refer to the predicates selected for the current line. -->
    <xsl:template match="SimplePredicate">
        <xsl:value-of select="@field"/>
        <xsl:if test="@operator = 'lessOrEqual'">
            <xsl:text> &lt;= </xsl:text>
        </xsl:if>
        <xsl:if test="@operator = 'greaterThan'">
            <xsl:text> &gt; </xsl:text>
        </xsl:if>
        <xsl:value-of select="@value"/>
        <xsl:if test="position() != last()">
            <xsl:text> ^ </xsl:text>
        </xsl:if>
    </xsl:template>
</xsl:stylesheet>

应用于您的输入示例,结果将是:

uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli <= 3.5 => B (347.0) 
uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli <= 3.5 => M (6.0) 
uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli > 3.5 => B (7.0) 
uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli > 3.5 => M (8.0) 
uniformitycellsize <= 3.5 ^ clumpthickness > 6.5 => B (2.0) 
uniformitycellsize <= 3.5 ^ clumpthickness > 6.5 => M (14.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness <= 7.5 => B (6.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness <= 7.5 => M (4.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness > 7.5 => B (1.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness > 7.5 => M (10.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion > 5.5 => B (0.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion > 5.5 => M (12.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize > 4.5 => B (2.0) 
uniformitycellsize > 3.5 ^ uniformitycellsize > 4.5 => M (140.0) 

或者,如果您愿意:

<!--
  Writes one text line per leaf Node (a Node with no child Node): the conditions
  on the path to the leaf joined by " ^ ", then " => " and every class value of
  the leaf with its record count, joined by "; ".
-->
<xsl:stylesheet xmlns:xsl="http://www.w3.org/1999/XSL/Transform" version="1.0">
    <xsl:output encoding="UTF-8" method="text"/>

    <!-- Entry point: one rule per leaf, in document order. -->
    <xsl:template match="/">
        <xsl:apply-templates select="//Node[not(Node)]" mode="rule"/>
    </xsl:template>

    <!-- A complete rule line for the current leaf. -->
    <xsl:template match="Node" mode="rule">
        <xsl:apply-templates select="ancestor-or-self::Node/SimplePredicate" mode="condition"/>
        <xsl:text> => </xsl:text>
        <xsl:apply-templates select="ScoreDistribution" mode="count"/>
        <xsl:text>&#10;</xsl:text>
    </xsl:template>

    <!-- A single condition; the " ^ " joiner is written before every condition but the first. -->
    <xsl:template match="SimplePredicate" mode="condition">
        <xsl:if test="position() != 1">
            <xsl:text> ^ </xsl:text>
        </xsl:if>
        <xsl:value-of select="@field"/>
        <xsl:choose>
            <xsl:when test="@operator = 'lessOrEqual'">
                <xsl:text> &lt;= </xsl:text>
            </xsl:when>
            <xsl:when test="@operator = 'greaterThan'">
                <xsl:text> &gt; </xsl:text>
            </xsl:when>
        </xsl:choose>
        <xsl:value-of select="@value"/>
    </xsl:template>

    <!-- A single "value (recordCount)" entry; the "; " joiner is written before every entry but the first. -->
    <xsl:template match="ScoreDistribution" mode="count">
        <xsl:if test="position() != 1">
            <xsl:text>; </xsl:text>
        </xsl:if>
        <xsl:value-of select="@value"/>
        <xsl:text> (</xsl:text>
        <xsl:value-of select="@recordCount"/>
        <xsl:text>)</xsl:text>
    </xsl:template>
</xsl:stylesheet>

输出为:

uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli <= 3.5 => B (347.0); M (6.0)
uniformitycellsize <= 3.5 ^ clumpthickness <= 6.5 ^ normalnucleoli > 3.5 => B (7.0); M (8.0)
uniformitycellsize <= 3.5 ^ clumpthickness > 6.5 => B (2.0); M (14.0)
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness <= 7.5 => B (6.0); M (4.0)
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion <= 5.5 ^ clumpthickness > 7.5 => B (1.0); M (10.0)
uniformitycellsize > 3.5 ^ uniformitycellsize <= 4.5 ^ marginaladhesion > 5.5 => B (0.0); M (12.0)
uniformitycellsize > 3.5 ^ uniformitycellsize > 4.5 => B (2.0); M (140.0)

答案 1 :(得分:0)

您可以编写一个xpath来从xml中获取叶子,并根据获得的信息构造对象。

例如,normalnucleoli的Xpath将是://*[@field][@field='normalnucleoli']/@value

java中使用上述xpath的示例代码将是:

// Parse the model file into a DOM tree and create an XPath evaluator.
// TreeModelXmlFile is expected to hold the path of the file to read.
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
DocumentBuilder db = dbf.newDocumentBuilder();
Document doc = db.parse( new File( TreeModelXmlFile ) );
XPathFactory xPathFactory = XPathFactory.newInstance();
XPath xpath = xPathFactory.newXPath();


String fieldToExtract = "normalnucleoli";

// Read the @value of the first element (in document order) whose @field
// equals fieldToExtract. Only the first match is used, even if the field
// is tested in several places of the tree.
String normalNucleoliValue = "";
XPathExpression expr = xpath.compile( "//*[@field][@field='" + fieldToExtract + "']/@value" );
Object exprEval = expr.evaluate( doc, XPathConstants.NODESET );
if ( exprEval instanceof NodeList )   // instanceof is already false for null
{
   NodeList nodeList = (NodeList)exprEval;
   if ( nodeList.getLength() > 0 )
   {
      // NodeList exposes item(int); it has no get(int) method.
      normalNucleoliValue = nodeList.item( 0 ).getTextContent();
   }
}

// Same lookup for the @operator attribute of that first matching element.
// The variables expr and exprEval are reused rather than declared again.
String operator = "";
expr = xpath.compile( "//*[@field][@field='" + fieldToExtract + "']/@operator" );
exprEval = expr.evaluate( doc, XPathConstants.NODESET );
if ( exprEval instanceof NodeList )
{
   NodeList nodeList = (NodeList)exprEval;
   if ( nodeList.getLength() > 0 )
   {
       operator = nodeList.item( 0 ).getTextContent();
   }
}

// Prints the raw operator name from the file, e.g. "normalnucleoli lessOrEqual 3.5".
System.out.println( fieldToExtract + " " + operator + " " + normalNucleoliValue );

或者:

您可以使用JAXB编写unmarshaller,将xml转换为java对象。为此您需要一个XML架构(XSD)。