Python 使用 scikit-learn 时,如何找到我的树分裂的属性?

Disclaimer: this page is a Chinese-English parallel translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you reuse or share it, you must follow the same license and attribute it to the original authors (not me). Original: http://stackoverflow.com/questions/20156951/

Date: 2020-08-18 19:41:46  Source: igfitidea

How do I find which attributes my tree splits on, when using scikit-learn?

Tags: python, machine-learning, scikit-learn, decision-tree

Asked by tumultous_rooster

I have been exploring scikit-learn, making decision trees with both entropy and gini splitting criteria, and exploring the differences.

My question is, how can I "open the hood" and find out exactly which attributes the trees are splitting on at each level, along with their associated information values, so I can see where the two criteria make different choices?

So far, I have explored the 9 methods outlined in the documentation. They don't appear to allow access to this information. But surely this information is accessible? I'm envisioning a list or dict that has entries for node and gain.

Thanks for your help and my apologies if I've missed something completely obvious.

Accepted answer by lejlot

Directly from the documentation (http://scikit-learn.org/0.12/modules/tree.html):

from io import StringIO
from sklearn import tree  # clf below is an already-fitted decision tree

out = StringIO()
tree.export_graphviz(clf, out_file=out)  # writes the DOT description into out

The StringIO module is no longer a top-level module in Python 3; import StringIO from the io module instead.
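A self-contained sketch of that snippet, updated for Python 3 (the iris data and the clf name here are illustrative assumptions, not part of the original answer):

```python
# Sketch: fit a tree and export its DOT description (iris is illustrative).
from io import StringIO

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(criterion="gini", random_state=0).fit(iris.data, iris.target)

out = StringIO()
export_graphviz(clf, out_file=out, feature_names=iris.feature_names)
dot_source = out.getvalue()  # DOT text; render it with graphviz, e.g. `dot -Tpng`
```

Passing feature_names makes the node labels show real column names instead of X[0], X[1], and so on.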

There is also the tree_ attribute on your decision tree object, which allows direct access to the whole structure.

And you can simply read it:

clf.tree_.children_left   # array of left children
clf.tree_.children_right  # array of right children
clf.tree_.feature         # array of the feature each node splits on
clf.tree_.threshold       # array of each node's split threshold
clf.tree_.value           # array of the class counts/values at each node
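Those arrays can be combined to answer the original question directly. A sketch (the iris data and entropy criterion are illustrative assumptions) that prints each node's split feature, threshold, and impurity — the entropy or gini value at that node:

```python
# Sketch: walk the tree_ arrays to see every split and its impurity value
# (iris and the entropy criterion here are illustrative assumptions).
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(criterion="entropy", random_state=0).fit(iris.data, iris.target)

t = clf.tree_
for node in range(t.node_count):
    if t.children_left[node] == -1:  # -1 marks a leaf node
        print(f"node {node}: leaf, impurity={t.impurity[node]:.3f}")
    else:
        name = iris.feature_names[t.feature[node]]
        print(f"node {node}: {name} <= {t.threshold[node]:.3f}, "
              f"impurity={t.impurity[node]:.3f}")
```

Running the same loop on a gini-criterion tree lets you diff the two sets of splits node by node.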

For more details, look at the source code of the export_graphviz method.

In general, you can use the inspect module

from inspect import getmembers
print(getmembers(clf.tree_))

to get all the object's elements

(Image: decision tree visualization from the sklearn docs)

Answered by Daniel Gibson

If you just want a quick look at what is going on in the tree, try:

list(zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right))  # list() needed in Python 3, where zip is lazy

where X is the data frame of independent variables and clf is the decision tree object. Notice that clf.tree_.children_left and clf.tree_.children_right together contain the order that the splits were made (each one of these would correspond to an arrow in the graphviz visualization).
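A runnable sketch of the trick above (the iris-backed DataFrame is an illustrative assumption; in Python 3, zip is lazy, so it is materialized with list to inspect the rows):

```python
# Sketch: the same zip trick, with an illustrative pandas DataFrame so that
# X.columns supplies the feature names.
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
clf = DecisionTreeClassifier(random_state=0).fit(X, iris.target)

# zip() is lazy in Python 3, so materialize it with list() to inspect it
rows = list(zip(X.columns[clf.tree_.feature],
                clf.tree_.threshold,
                clf.tree_.children_left,
                clf.tree_.children_right))
```

Each row is (feature name, threshold, left child index, right child index); rows for leaf nodes carry placeholder feature/threshold values and child indices of -1.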

Answered by yzerman

Scikit-learn introduced a delicious new method called export_text in version 0.21 (May 2019) to view all the rules from a tree. Documentation here.

Once you've fit your model, you just need two lines of code. First, import export_text:

from sklearn.tree import export_text  # older releases used: from sklearn.tree.export import export_text

Second, create an object that will contain your rules. To make the rules look more readable, use the feature_names argument and pass a list of your feature names. For example, if your model is called model and your features are named in a dataframe called X_train, you could create an object called tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Then just print or save tree_rules. Your output will look like this:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
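The two lines above can be sketched end to end (iris data is assumed here for illustration; the Age/EstimatedSalary output above came from the asker's own data):

```python
# Sketch: the full export_text workflow on illustrative iris data.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

tree_rules = export_text(model, feature_names=iris.feature_names)
print(tree_rules)  # prints rules in the |--- indented format shown above
```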