How do I find which attributes my tree splits on, when using scikit-learn?
Disclaimer: this content is provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must attribute it to the original authors (not me): StackOverflow
Original URL: http://stackoverflow.com/questions/20156951/
Asked by tumultous_rooster
I have been exploring scikit-learn, making decision trees with both entropy and gini splitting criteria, and exploring the differences.
My question is: how can I "open the hood" and find out exactly which attributes the trees are splitting on at each level, along with their associated information values, so I can see where the two criteria make different choices?
So far, I have explored the 9 methods outlined in the documentation. They don't appear to allow access to this information. But surely this information is accessible? I'm envisioning a list or dict that has entries for node and gain.
Thanks for your help and my apologies if I've missed something completely obvious.
Accepted answer by lejlot
Directly from the documentation (http://scikit-learn.org/0.12/modules/tree.html):
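from io import StringIO
from sklearn import tree
out = StringIO()
tree.export_graphviz(clf, out_file=out)  # writes the DOT source of the fitted tree `clf` into `out`
dot_source = out.getvalue()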
The StringIO module no longer exists in Python 3; import StringIO from the io module instead.
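For a self-contained check, here is a minimal sketch that fits a small tree on the iris dataset (chosen only for illustration) and gets the DOT source in two ways: through a StringIO buffer as above, and by passing out_file=None, in which case export_graphviz returns the DOT text as a string. Each node label in that text includes the split feature, the threshold and the node's impurity (entropy here), which is already enough to compare criteria by eye.

```python
from io import StringIO

from sklearn import tree
from sklearn.datasets import load_iris

# Small illustrative tree; iris is just a stand-in for your own data.
iris = load_iris()
clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# Option 1: write the DOT source into an in-memory text buffer.
out = StringIO()
tree.export_graphviz(clf, out_file=out, feature_names=iris.feature_names)
dot_from_buffer = out.getvalue()

# Option 2: with out_file=None, export_graphviz returns the DOT source directly.
dot_text = tree.export_graphviz(clf, out_file=None, feature_names=iris.feature_names)

print(dot_text)
```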
There is also the tree_ attribute on your decision tree object, which gives direct access to the whole tree structure.
You can read its arrays directly:
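clf.tree_.children_left   # array of left-child node ids (-1 for leaves)
clf.tree_.children_right  # array of right-child node ids (-1 for leaves)
clf.tree_.feature         # array of the feature index each node splits on (-2 for leaves)
clf.tree_.threshold       # array of the split threshold at each node (-2 for leaves)
clf.tree_.impurity        # array of each node's impurity (gini or entropy value)
clf.tree_.value           # array of per-node values (e.g. class distribution for classifiers)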
For more details, look at the source code of the export_graphviz function.
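Since the question asks for the information value behind each split, the sketch below (an illustration, not part of the original answer) combines children_left, children_right, feature and threshold with two further tree_ arrays, impurity and weighted_n_node_samples, to build a list with one dict per split node. The helper name split_table and the iris dataset are assumptions for illustration. "gain" is the node's impurity minus the size-weighted impurity of its two children (the information gain when criterion="entropy"); "weighted_gain" scales that by the fraction of samples reaching the node, which is the per-node term scikit-learn sums per feature for feature_importances_. Fitting once with "gini" and once with "entropy" shows where the two trees choose different splits.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def split_table(clf, feature_names):
    """Hypothetical helper: one dict per split (internal) node of a fitted tree."""
    t = clf.tree_
    total = t.weighted_n_node_samples[0]  # (weighted) number of samples at the root
    rows = []
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == right:  # leaves: both children are -1
            continue
        n = t.weighted_n_node_samples[node]
        n_l = t.weighted_n_node_samples[left]
        n_r = t.weighted_n_node_samples[right]
        # node impurity minus the size-weighted impurity of its children
        # (for criterion="entropy" this is the information gain of the split)
        gain = t.impurity[node] - (n_l * t.impurity[left] + n_r * t.impurity[right]) / n
        rows.append({
            "node": node,
            "feature": feature_names[t.feature[node]],
            "threshold": float(t.threshold[node]),
            "impurity": float(t.impurity[node]),
            "gain": float(gain),
            # gain scaled by the fraction of samples reaching this node
            "weighted_gain": float(gain * n / total),
        })
    return rows


iris = load_iris()
for criterion in ("gini", "entropy"):
    clf = DecisionTreeClassifier(criterion=criterion, max_depth=3, random_state=0)
    clf.fit(iris.data, iris.target)
    print(criterion)
    for row in split_table(clf, iris.feature_names):
        print(row)
```

Summing "weighted_gain" per feature and normalising should reproduce clf.feature_importances_, which is a handy sanity check that the table reads the arrays correctly.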
In general, you can use the inspect module
from inspect import getmembers
print( getmembers( clf.tree_ ) )
to list all of the object's members (attributes and methods).


Answer by Daniel Gibson
If you just want a quick look at what is going on in the tree, try:
list(zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right))
# leaf nodes have feature == -2, so the column name paired with them is meaningless
where X is the DataFrame of independent variables the tree was trained on and clf is the fitted decision tree. Note that clf.tree_.children_left and clf.tree_.children_right together contain the order in which the splits were made: entry i holds the ids of node i's left and right children, and each parent-child link corresponds to an arrow in the graphviz visualization.
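A slightly more defensive version of the one-liner, sketched here under the assumption that X is a pandas DataFrame (iris loaded with as_frame=True stands in for your data), keeps only the split nodes, so leaves do not get a misleading column name from the -2 feature index.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Stand-in data: X is a DataFrame, as in the answer above.
X, y = load_iris(return_X_y=True, as_frame=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

t = clf.tree_
is_split = t.children_left != t.children_right  # False for leaf nodes

splits = pd.DataFrame({
    "node": range(t.node_count),
    # only look up a column name where the node actually splits
    "feature": [X.columns[f] if s else None for f, s in zip(t.feature, is_split)],
    "threshold": t.threshold,
    "left": t.children_left,
    "right": t.children_right,
}).loc[is_split]

print(splits)
```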
Answer by yzerman
scikit-learn introduced a delicious new function called export_text in version 0.21 (May 2019) for viewing all the rules of a tree. See the export_text documentation for details.
Once you've fit your model, you just need two lines of code. First, import export_text:
from sklearn.tree import export_text
Second, create a string containing your rules. To make the rules more readable, use the feature_names argument and pass a list of your feature names. For example, if your model is called model and your features are the columns of a DataFrame called X_train, you could create a string called tree_rules:
tree_rules = export_text(model, feature_names=list(X_train))
Then just print or save tree_rules. Your output will look like this:
|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age > -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary > -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
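To tie this back to the original question, here is a small sketch (iris again used only as a stand-in dataset) that fits one tree per criterion and prints the export_text rules for each, so the places where gini and entropy pick different splits are easy to spot. The decimals argument controls how many digits thresholds are printed with.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
rules = {}
for criterion in ("gini", "entropy"):
    model = DecisionTreeClassifier(criterion=criterion, max_depth=3, random_state=0)
    model.fit(iris.data, iris.target)
    # export_text returns a plain string; feature_names needs one name per column
    rules[criterion] = export_text(model, feature_names=list(iris.feature_names), decimals=2)

for criterion, text in rules.items():
    print(f"=== {criterion} ===")
    print(text)
```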

