Python: how to interpret a decision tree from scikit-learn
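Disclaimer: This page is a Chinese/English translation of a popular StackOverflow question, licensed under CC BY-SA 4.0. If you use it, you must also comply with the CC BY-SA license, cite the original URL and author information, and attribute it to the original authors (not me): StackOverflow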
Original URL: http://stackoverflow.com/questions/23557545/
Question by Student Hyman
I have two questions about interpreting the output of a decision tree from scikit-learn. For example, here is one of my decision trees:
My question is: how do I use the tree?
First question: if a sample satisfies the condition, it goes to the LEFT branch (if one exists); otherwise it goes RIGHT. In my case, if a sample has X[7] > 63521.3984, it will go to the green box. Is that correct?
Second question: when a sample reaches a leaf node, how do I know which category it belongs to? In this example I have three categories. In the red box, 91, 212, and 113 samples, respectively, satisfy the condition. How do I decide the category? I know the function clf.predict(sample) tells me the category, but can I read it from the graph? Many thanks.
Accepted answer by BrenBarn
The value line in each box tells you how many training samples at that node fall into each category, in order (the order of clf.classes_). That's why, in each box, the numbers in value add up to the number shown in samples. For instance, in your red box, 91 + 212 + 113 = 416. So if you reach this node, there were 91 data points in category 1, 212 in category 2, and 113 in category 3.
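A minimal sketch, assuming iris as stand-in data and a depth-2 tree (neither comes from the question), that recounts each node's value line from the training data with DecisionTreeClassifier.decision_path and checks that the per-class counts add up to the node's samples number:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# decision_path returns a sparse (n_samples, n_nodes) indicator matrix:
# entry [i, j] is 1 if training sample i passes through node j.
paths = clf.decision_path(X).toarray()

# For every node, count how many training samples of each class pass through it.
# Columns follow clf.classes_, the same order as the "value" line in the graph.
counts = np.stack([paths[y == c].sum(axis=0) for c in clf.classes_], axis=1)

# The per-class counts of each node add up to its "samples" number.
assert np.array_equal(counts.sum(axis=1), clf.tree_.n_node_samples)

for node_id, (row, n) in enumerate(zip(counts, clf.tree_.n_node_samples)):
    print(f"node {node_id}: value = {row.tolist()}, samples = {n}")
```

Node 0 is the root, so its row is simply the class count of the whole training set.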
If you were going to predict the outcome for a new data point that reached that leaf in the decision tree, you would predict category 2, because that is the most common category for samples at that node.
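The same majority vote can be done by hand. This sketch (again on iris, an assumed example) looks up the leaf each sample reaches with apply, reads that leaf's row of tree_.value, and takes the most common class; the result should match clf.predict. In some scikit-learn versions tree_.value stores per-class fractions rather than raw counts, but the largest entry is the same either way.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

leaf_ids = clf.apply(X)                        # index of the leaf each sample lands in
leaf_values = clf.tree_.value[leaf_ids, 0, :]  # that leaf's "value" row (single output)

# Most common class in each leaf; clf.classes_ maps column index -> class label.
by_hand = clf.classes_[np.argmax(leaf_values, axis=1)]

assert np.array_equal(by_hand, clf.predict(X))
```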
Answer by user3784777
According to the book "Learning scikit-learn: Machine Learning in Python", the decision tree represents a series of decisions based on the training data.
![decision tree](http://i.imgur.com/vM9fJLy.png)
To classify an instance, answer the question at each node. For example: is sex <= 0.5 (are we talking about a woman)? If the answer is yes, go to the left child node in the tree; otherwise, go to the right child node. Keep answering questions (was she in third class? was she in first class? was she younger than 13?) until you reach a leaf. There, the prediction is the target class with the most instances.
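Answering the question at each node is a simple loop over the arrays stored in the fitted tree_ object. The sketch below is illustrative: leaf_for is a hypothetical helper name, and iris stands in for the Titanic data. It follows scikit-learn's rule of going left when feature <= threshold and checks the result against clf.apply:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def leaf_for(clf, x):
    """Walk one sample down a fitted tree by answering the question at each node."""
    t = clf.tree_
    x = np.asarray(x, dtype=np.float32)   # scikit-learn compares features as float32
    node = 0                              # start at the root
    while t.children_left[node] != -1:    # -1 marks a leaf (no children)
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]  # answer "yes": go to the left child
        else:
            node = t.children_right[node] # answer "no": go to the right child
    return node

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

leaf = leaf_for(clf, X[0])
assert leaf == clf.apply(X[:1])[0]
# The prediction at the leaf is the class with the most training instances there.
print(leaf, clf.classes_[np.argmax(clf.tree_.value[leaf])])
```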
Answer by MyopicVisage
First question: Yes, your logic is correct. The left child is the True branch and the right child is the False branch. This can be counterintuitive: every split is written as feature <= threshold, so True corresponds to the smaller values.
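A quick check of the "left = True" convention on an assumed iris tree: count the training samples that satisfy the root's condition and compare the result with the samples number of each child.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
t = clf.tree_

f, thr = t.feature[0], t.threshold[0]          # the root's question: X[f] <= thr ?
goes_left = X.astype(np.float32)[:, f] <= thr  # True for samples that satisfy it

left, right = t.children_left[0], t.children_right[0]
assert goes_left.sum() == t.n_node_samples[left]      # True  -> left child
assert (~goes_left).sum() == t.n_node_samples[right]  # False -> right child
print(f"X[{f}] <= {thr:.4f}: {goes_left.sum()} left, {(~goes_left).sum()} right")
```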
Second question: This is best resolved by visualizing the tree as a graph with pydotplus. The class_names parameter of tree.export_graphviz() adds a label naming the majority class of each node. The code below was run in an IPython notebook.
```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)

# Write the tree to a .dot file, then delete it again (this only demonstrates the file API)
with open("iris.dot", "w") as f:
    tree.export_graphviz(clf2, out_file=f)

import os
os.unlink("iris.dot")

# Render the tree to PDF via pydotplus (requires Graphviz to be installed)
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")

# Inline rendering in an IPython notebook, with feature and class names on each node
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,  # leaves_parallel=True,
                                special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)

## Color of nodes: fill each node with a color keyed to its majority class
color = {0: [255, 255, 224], 1: [255, 224, 255], 2: [224, 255, 255]}
nodes = graph2.get_node_list()
for node in nodes:
    if node.get_label():
        # Parse the "value = [a, b, c]" part of the node label
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        rgb = color[values.index(max(values))]
        node.set_fillcolor('#{:02x}{:02x}{:02x}'.format(*rgb))

Image(graph2.create_png())
```
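If Graphviz or pydotplus is not available, the effect of class_names can still be inspected in the DOT text that export_graphviz returns when out_file=None. A small sketch on a depth-2 iris tree (an assumed setup, not the answerer's code):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# out_file=None returns the DOT source as a string; no Graphviz install is needed
dot_data = export_graphviz(clf, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=list(iris.target_names))

# Each node label now ends with "class = <majority class>"
class_lines = [line for line in dot_data.splitlines() if "class = " in line]
print(len(class_lines), class_lines[0])
```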
As for determining the class at a leaf: unlike the iris example, your tree doesn't have leaves that contain a single class. This is common, and getting single-class leaves may require overfitting the model. For many cross-validated models, leaves holding a mix of classes are the best achievable result.
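To see the difference between mixed and single-class leaves, this sketch (on iris, with depth limits chosen only for illustration) counts the leaves whose impurity is above zero for a shallow tree and for a fully grown one; impure_leaves is a hypothetical helper:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

def impure_leaves(clf):
    """Number of leaves whose training samples are not all from one class."""
    t = clf.tree_
    is_leaf = t.children_left == -1
    return int(np.sum(t.impurity[is_leaf] > 0))

shallow = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)
# No depth limit: nodes are split until they are pure or cannot be split further
full = DecisionTreeClassifier(random_state=0).fit(X, y)

print("depth 1:", impure_leaves(shallow), "impure leaves of", shallow.get_n_leaves())
print("full:   ", impure_leaves(full), "impure leaves of", full.get_n_leaves())
```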
Enjoy the code!
Answer by Roo
Add feature_names=X.columns to tree.export_graphviz where X is the training data.
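A self-contained version of the same idea, assuming a pandas DataFrame with made-up short column names standing in for the answerer's X. The column names are converted to a plain list here to be safe across scikit-learn versions; passing X.columns directly, as in the answer, should also work in recent versions.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
# Hypothetical training data: the iris features as a DataFrame with named columns
X = pd.DataFrame(iris.data, columns=["sepal_len", "sepal_wid", "petal_len", "petal_wid"])
y = iris.target

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# With feature_names, splits read e.g. "petal_len <= 2.45" instead of "X[2] <= 2.45"
dot_data = export_graphviz(clf, out_file=None, feature_names=list(X.columns))
root = next(line for line in dot_data.splitlines() if line.startswith("0 [label"))
print(root)
```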
My code is as follows
```python
# lectureGini is the fitted DecisionTreeClassifier; X is the training DataFrame
with open("lectureGini.txt", "w") as f:
    tree.export_graphviz(lectureGini, out_file=f, feature_names=X.columns)
# copy contents of file lectureGini.txt into WebGraphviz - http://webgraphviz.com/
```
lectureGini is my fitted DecisionTreeClassifier.
This is a simple addition I discovered that could be made to all the web examples of the Gini index I had researched. They all explained the method really well, but none showed how to find the categories. I don't have Graphviz installed yet, so I export a text file from Jupyter and copy the text into WebGraphviz.
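Another option when Graphviz is not installed, not mentioned in the answer: scikit-learn (0.21 and later) can print the tree as plain text with sklearn.tree.export_text, or draw it with matplotlib via sklearn.tree.plot_tree. A minimal export_text sketch on iris, fitting on the class names so that the leaves show readable labels:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
# Fit on the class names themselves so leaves print "class: setosa" rather than "class: 0"
y_names = iris.target_names[iris.target]
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, y_names)

report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```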