Python: How to extract the decision rules from a scikit-learn decision tree?

Disclaimer: this page is a translation of a popular StackOverflow question, licensed under CC BY-SA 4.0. If you use it, you must also follow the CC BY-SA license, cite the original URL and authors, and attribute it to the original authors (not me). StackOverflow original: http://stackoverflow.com/questions/20224526/

Date: 2020-08-18 19:55:21  Source: igfitidea


Tags: python, machine-learning, scikit-learn, decision-tree, random-forest

Question by Dror Hilman

Can I extract the underlying decision rules (or 'decision paths') from a trained decision tree as a textual list?


Something like:


if A>0.4 then if B<0.2 then if C>0.8 then class='X'


Thanks for your help.


Accepted answer by paulkernfeld

I believe that this answer is more correct than the other answers here:


from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # Leaf nodes have feature == TREE_UNDEFINED (-2), so don't index feature_names with it
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Decision node: left branch is "<= threshold", right branch is "> threshold"
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: return the stored value array
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)

This prints out a valid Python function. Here's an example output for a tree that is trying to return its input, a number between 0 and 10.


def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
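To check that the generated function really is valid Python, one option is to collect the lines in a list and `exec` them instead of printing. A sketch under stated assumptions: a `SimpleNamespace` stands in for a fitted `tree.tree_` (a hypothetical three-node tree whose attribute names match sklearn's), and `TREE_UNDEFINED = -2` is hard-coded to match the value sklearn uses for leaf features.

```python
from types import SimpleNamespace

TREE_UNDEFINED = -2  # sklearn marks leaf nodes with feature == -2


def tree_to_source(tree_, feature_names):
    """Return Python source for a function that mirrors the tree's splits."""
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != TREE_UNDEFINED:
            # Only split nodes look up a feature name, so -2 is never used as an index
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            lines.append("{}return {!r}".format(indent, tree_.value[node]))

    recurse(0, 1)
    return "\n".join(lines)


# Hypothetical 3-node tree: split on f0 at 6.0, leaves return 3.0 and 8.0.
toy = SimpleNamespace(
    feature=[0, TREE_UNDEFINED, TREE_UNDEFINED],
    threshold=[6.0, -2.0, -2.0],
    children_left=[1, -1, -1],
    children_right=[2, -1, -1],
    value=[None, 3.0, 8.0],
)

source = tree_to_source(toy, ["f0"])
namespace = {}
exec(source, namespace)  # defines namespace["tree"]
print(source)
```

With a real fitted tree, `tree_.value[node]` is a NumPy array; emitting `tree_.value[node].tolist()` instead keeps the generated source free of NumPy reprs.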

Here are some stumbling blocks that I see in other answers:


  1. Using tree_.threshold == -2 to decide whether a node is a leaf isn't a good idea. What if it's a real decision node with a threshold of -2? Instead, you should look at tree_.feature or tree_.children_*.
  2. The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).
  3. There is no need to have multiple if statements in the recursive function; just one is fine.
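The first two points can be demonstrated on hand-built arrays (hypothetical, laid out like sklearn's `tree_`): node 0 below is a genuine split whose threshold happens to be -2.0, so the threshold test wrongly reports it as a leaf while the children test does not, and the name lookup guards against indexing `feature_names` with the leaf marker -2.

```python
TREE_LEAF = -1       # sklearn's marker for "no child"
TREE_UNDEFINED = -2  # sklearn's marker for "no feature" on leaves

# Hypothetical tree: node 0 really splits feature 0 at threshold -2.0.
children_left = [1, TREE_LEAF, TREE_LEAF]
children_right = [2, TREE_LEAF, TREE_LEAF]
feature = [0, TREE_UNDEFINED, TREE_UNDEFINED]
threshold = [-2.0, -2.0, -2.0]
feature_names = ["temperature"]  # with one name, feature_names[-2] would raise IndexError


def is_leaf_by_threshold(node):
    # Fragile: a real split can have threshold -2.0
    return threshold[node] == -2


def is_leaf_by_children(node):
    # Robust: leaves have no children
    return children_left[node] == TREE_LEAF


def safe_feature_name(node):
    # Only look up a name for split nodes
    i = feature[node]
    return feature_names[i] if i != TREE_UNDEFINED else None


print([is_leaf_by_threshold(n) for n in range(3)])  # [True, True, True]
print([is_leaf_by_children(n) for n in range(3)])   # [False, True, True]
print([safe_feature_name(n) for n in range(3)])     # ['temperature', None, None]
```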

Answer by lennon310

from io import StringIO
from sklearn import tree

out = StringIO()
tree.export_graphviz(clf, out_file=out)  # writes the DOT text into the buffer
print(out.getvalue())

You will see a digraph Tree. Then, clf.tree_.feature and clf.tree_.value are the array of splitting features per node and the array of node values, respectively. You can find more details in this GitHub source.
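In current scikit-learn, `export_graphviz` returns the DOT source as a string when `out_file=None`, so the `StringIO` buffer is optional. A minimal sketch on the bundled iris data (the dataset and `max_depth=2` are arbitrary choices for illustration), also showing the shapes of the per-node arrays mentioned above:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT text instead of writing it
dot_source = export_graphviz(clf, out_file=None, feature_names=iris.feature_names)
print(dot_source.splitlines()[0])

# The per-node arrays the DOT output is built from
print(clf.tree_.feature)       # split feature index per node, -2 on leaves
print(clf.tree_.value.shape)   # (n_nodes, n_outputs, n_classes)
```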


Answer by Zelazny7

I created my own function to extract the rules from the decision trees created by sklearn:


import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.iloc[:, :2], df.dv)  # first two columns are the features

This function starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents. I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:


def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print(node)

The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using do blocks in SAS, which is why I create logic describing a node's entire path. The single integer after the tuples is the ID of the terminal node in a path. All of the preceding tuples combine to create that node.


In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

[Figure: GraphViz output of the example tree]
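The lineage-to-SAS idea can be sketched without the repeated `np.where` searches by building a child-to-parent map in one pass, then formatting each root-to-leaf path as a SAS `if ... then` statement. The arrays are a hand-built, hypothetical copy of the example tree (splits col1 <= 0.5, col2 <= 4.5, col1 <= 2.5), and the `node_id` variable in the generated SAS is an assumed name.

```python
# Hypothetical hand-built copy of the example tree, laid out like sklearn's tree_
children_left = [1, -1, 3, -1, 5, -1, -1]
children_right = [2, -1, 4, -1, 6, -1, -1]
threshold = [0.5, -2.0, 4.5, -2.0, 2.5, -2.0, -2.0]
feature = [0, -2, 1, -2, 0, -2, -2]
feature_names = ["col1", "col2"]

# One pass over the nodes: child id -> (parent id, 'l' or 'r')
parent_of = {}
for node, (l, r) in enumerate(zip(children_left, children_right)):
    if l != -1:
        parent_of[l] = (node, "l")
        parent_of[r] = (node, "r")


def lineage(leaf):
    """Return [(parent, side, threshold, feature), ..., leaf], root first."""
    path, node = [], leaf
    while node in parent_of:
        parent, side = parent_of[node]
        path.append((parent, side, threshold[parent], feature_names[feature[parent]]))
        node = parent
    return path[::-1] + [leaf]


def lineage_to_sas(path):
    """Build one SAS IF/THEN statement; 'node_id' is a hypothetical output variable."""
    *splits, leaf = path
    conditions = ["{} {} {}".format(name, "<=" if side == "l" else ">", thr)
                  for _parent, side, thr, name in splits]
    # A single-node tree has no conditions; "1" keeps the statement valid
    return "if {} then node_id = {};".format(" and ".join(conditions) or "1", leaf)


leaves = [n for n, l in enumerate(children_left) if l == -1]
for leaf in leaves:
    print(lineage_to_sas(lineage(leaf)))
```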


Answer by Daniele

I modified the code submitted by Zelazny7 to print some pseudocode:


def get_code(tree, feature_names):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        if (threshold[node] != -2):
            # Decision node: emit the test, then both branches
            print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node])
            print("} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            # Leaf node: emit the stored value array
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)

If you call get_code(dt, df.columns) on the same example, you will obtain:


if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
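Each `return` above prints `tree_.value` for a leaf: an array with one row per output and one column per class. Depending on the scikit-learn version it holds class counts or class fractions, but either way the predicted class is the column with the largest entry. A sketch of mapping a leaf value to a label, using hand-copied values and a stand-in for `clf.classes_`:

```python
import numpy as np

# Hypothetical copies of leaf values printed above; shape (n_outputs, n_classes)
leaf_values = [
    np.array([[1.0, 0.0]]),
    np.array([[0.0, 1.0]]),
]
classes_ = np.array([0, 1])  # stand-in for clf.classes_ of a fitted classifier


def leaf_class(value, classes):
    """Predicted class for one leaf: the class with the largest count or fraction."""
    return classes[np.argmax(value[0])]


predicted = [leaf_class(v, classes_) for v in leaf_values]
print(predicted)
```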

Answer by Apogentus

Here is a function that prints the rules of a scikit-learn decision tree under Python 3, indenting the conditional blocks to make the structure more readable:


def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Prints a textual representation of the rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f0, f1, f2, ... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)
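Newer scikit-learn releases (0.21 and later) ship a built-in helper for this kind of indented text dump, `sklearn.tree.export_text`. A minimal sketch on the bundled iris data; the dataset and `max_depth=2` are just for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text returns the rules as an indented plain-text string
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

`export_text` also accepts `max_depth`, `spacing`, `decimals` and `show_weights` to control the output.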

Answer by TED Zhao

The code below is my approach under Anaconda Python 2.7, plus a package named "pydot-ng", to make a PDF file with the decision rules. I hope it is helpful.


from sklearn import tree

# X (feature DataFrame), data_y (target) and n (max number of leaves) are defined elsewhere
clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    import pydot_ng as pydot
    # out_file=None makes export_graphviz return the DOT source as a string
    dot_data = tree.export_graphviz(clf_, out_file=None,
                                    feature_names=feature_names,
                                    class_names=class_name,
                                    filled=True, rounded=True,
                                    special_characters=True,
                                    node_ids=True)
    graph = pydot.graph_from_dot_data(dot_data)
    graph.write_pdf("%s.pdf" % name)

output_pdf(clf_, name='filename%s' % n)

[Figure: the resulting tree graph]
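If the goal is only a PDF, newer scikit-learn (0.21 and later) can also draw the tree with matplotlib via `sklearn.tree.plot_tree`, which avoids pydot and the Graphviz binaries. A sketch; the iris data, `max_leaf_nodes=4`, figure size and file name are placeholders:

```python
import os

import matplotlib
matplotlib.use("Agg")  # render off-screen, no display needed
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=4, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=[str(name) for name in iris.target_names],
    filled=True,
    rounded=True,
    node_ids=True,
    ax=ax,
)
fig.savefig("tree.pdf")  # the .pdf extension selects PDF output
plt.close(fig)

pdf_written = os.path.getsize("tree.pdf") > 0
```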


Answer by Ruslan

Since everyone was so helpful, I'll add a modification to Zelazny7's and Daniele's beautiful solutions. This one is for Python 2.7, with tabs to make it more readable:


def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
        indent = '\t' * tabdepth  # one tab per nesting level
        if (threshold[node] != -2):
            print(indent + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node], tabdepth + 1)
            print(indent + "} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node], tabdepth + 1)
            print(indent + "}")
        else:
            print(indent + "return " + str(value[node]))

    recurse(left, right, threshold, features, 0, tabdepth)

Answer by Kevin

There is a new DecisionTreeClassifier method, decision_path, in the 0.18.0 release. The developers provide an extensive (well-documented) walkthrough.


The first section of code in the walkthrough, which prints the tree structure, seems to be OK. However, I modified the code in the second section to interrogate one sample. My changes are marked with # <--


Edit: The changes marked by # <-- in the code below have since been incorporated into the walkthrough, after the errors were pointed out in pull requests #8653 and #10951. It's much easier to follow along now.


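# Setup from the walkthrough (estimator is a fitted DecisionTreeClassifier):
#   feature = estimator.tree_.feature
#   threshold = estimator.tree_.threshold
#   node_indicator = estimator.decision_path(X_test)
#   leave_id = estimator.apply(X_test)
sample_id = 0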
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # <-- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Change the sample_id to see the decision paths for other samples. I haven't asked the developers about these changes; they just seemed more intuitive when working through the example.
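For reference, a self-contained version of the same loop, assuming a recent scikit-learn and the bundled iris data (the dataset, `max_depth=3` and `sample_id = 0` are arbitrary): `decision_path` returns a sparse CSR node-indicator matrix, and `apply` gives the leaf each sample ends in.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

feature = clf.tree_.feature
threshold = clf.tree_.threshold
node_indicator = clf.decision_path(X)  # sparse CSR matrix, shape (n_samples, n_nodes)
leaf_id = clf.apply(X)                 # leaf index reached by each sample

sample_id = 0
# Nodes visited by this sample; parents get smaller ids than children, so root comes first
node_index = node_indicator.indices[
    node_indicator.indptr[sample_id]:node_indicator.indptr[sample_id + 1]
]

rules = []
for node_id in node_index:
    if node_id == leaf_id[sample_id]:
        rules.append("leaf {}".format(node_id))
        continue
    value = X[sample_id, feature[node_id]]
    sign = "<=" if value <= threshold[node_id] else ">"
    rules.append("X[{}, {}] = {} {} {:.3f}".format(
        sample_id, feature[node_id], value, sign, threshold[node_id]))

print("\n".join(rules))
```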


Answer by Arslán

I modified Zelazny7's code to generate SQL (a CASE expression) from the decision tree.


# SQL from decision tree
import numpy as np

def get_lineage(tree, feature_names):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    le = '<='
    g  = '>'
    # get ids of leaf nodes
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'
        lineage.append((parent, split, threshold[parent], features[parent]))
        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    print('case ')
    for j, child in enumerate(idx):
        clause = ' when '
        for node in recurse(left, right, child):
            if not isinstance(node, tuple):  # skip the trailing leaf id
                continue
            sign = le if node[1] == 'l' else g
            clause = clause + node[3] + sign + str(node[2]) + ' and '
        clause = clause[:-4] + ' then ' + str(j)  # drop the final 'and '
        print(clause)
    print('else 99 end as clusters')
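Instead of walking up from each leaf, the same CASE expression can be built top-down in a single recursive pass that carries the conditions along each root-to-leaf path. A sketch with hypothetical hand-built arrays in the `tree_` layout; unlike the code above, it labels each branch with the leaf node id rather than an enumeration counter, and it assumes the feature names are valid SQL column names.

```python
def tree_to_sql_case(children_left, children_right, feature, threshold, feature_names):
    """Return a SQL CASE expression with one WHEN branch per leaf, labelled by leaf id."""
    whens = []

    def recurse(node, conditions):
        if children_left[node] == -1:  # leaf: emit the accumulated path
            condition = " and ".join(conditions) or "1=1"  # single-node tree
            whens.append("  when {} then {}".format(condition, node))
            return
        name = feature_names[feature[node]]
        thr = threshold[node]
        recurse(children_left[node], conditions + ["{} <= {}".format(name, thr)])
        recurse(children_right[node], conditions + ["{} > {}".format(name, thr)])

    recurse(0, [])
    return "case\n" + "\n".join(whens) + "\nend as clusters"


sql = tree_to_sql_case(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature=[0, -2, 1, -2, -2],
    threshold=[0.5, -2.0, 4.5, -2.0, -2.0],
    feature_names=["col1", "col2"],
)
print(sql)
```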

Answer by horseshoe

This builds on @paulkernfeld's answer. If you have a dataframe X with your features and a target dataframe y with your responses, and you want to find out which y values ended up in which node (and also plot them accordingly), you can do the following:


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    codelines = []
    codelines.append('def get_cat(X_tmp):\n')
    codelines.append('   catout = []\n')
    codelines.append('   for row in range(0,X_tmp.shape[0]):\n')
    codelines.append('      Xin = X_tmp.iloc[row]\n')
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    def recurse(node, depth):
        indent = "      " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            codelines.append('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            codelines.append('{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # leaf: record the node id as the category
            codelines.append('{}mycat = {}\n'.format(indent, node))

    recurse(0, 1)
    codelines.append('      catout.append(mycat)\n')
    codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
    codelines.append('node_ids = get_cat(X)\n')
    return codelines

# clf is the fitted tree, X the feature DataFrame and y the target (defined elsewhere)
mycode = tree_to_code(clf, X.columns.values)

# now execute the generated code and obtain the dataframe with all nodes
exec(''.join(mycode))
node_ids = np.array([int(x[0]) for x in node_ids.values])

print('make plot')
colors = cm.rainbow(np.linspace(0, 1, 1 + node_ids.max()))
y_values = np.asarray(y).ravel()
for i in sorted(set(node_ids)):
    plt.plot(y_values[node_ids == i], 'o', color=colors[i], label=str(i))
plt.title('y colored by node', fontsize=14)
plt.xlabel('my xlabel')
plt.ylabel('my ylabel')
plt.xticks(rotation=70)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
plt.tight_layout()
plt.show()
plt.close()

It's not the most elegant version, but it does the job...
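For the node-per-sample part, scikit-learn already provides `apply`: it returns, for every row, the id of the leaf node the row ends up in, which is what the generated `get_cat` function computes. A sketch on synthetic data (the data, `max_depth=3` and the grouping loop are illustrative assumptions):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(200, 1))
y = X[:, 0] + rng.normal(scale=0.5, size=200)

reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
leaf_ids = reg.apply(X)  # one leaf node id per sample

# Group the targets by leaf, e.g. to colour a plot by node
for leaf in np.unique(leaf_ids):
    members = y[leaf_ids == leaf]
    print("node {}: {} samples, mean y = {:.2f}".format(leaf, members.size, members.mean()))
```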
