How can I plot a confusion matrix in Python?
Note: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must follow the same license, cite the original URL, and attribute it to the original authors (not me): StackOverFlow
Original URL: http://stackoverflow.com/questions/35572000/
Asked by minks
I am using scikit-learn to classify text documents (22,000) into 100 classes. I use scikit-learn's confusion_matrix function to compute the confusion matrix.
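from sklearn import metrics
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt

# matrix, labels, test_matrix and test_labels are the asker's own training and test data
model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm = metrics.confusion_matrix(test_labels, pred)
print(cm)
plt.imshow(cm, cmap='binary')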
This is what my confusion matrix looks like:
[[3962 325 0 ..., 0 0 0]
[ 250 2765 0 ..., 0 0 0]
[ 2 8 17 ..., 0 0 0]
...,
[ 1 6 0 ..., 5 0 0]
[ 1 1 0 ..., 0 0 0]
[ 9 0 0 ..., 0 0 9]]
However, the resulting plot is not clear or legible. Is there a better way to do this?
Answered by bninopaul
You can use plt.matshow() instead of plt.imshow(), or you can use the seaborn module's heatmap (see documentation) to plot the confusion matrix:
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

# Example confusion matrix for 11 classes
array = [[33,2,0,0,0,0,0,0,0,1,3],
         [3,31,0,0,0,0,0,0,0,0,0],
         [0,4,41,0,0,0,0,0,0,0,1],
         [0,1,0,30,0,6,0,0,0,0,1],
         [0,0,0,0,38,10,0,0,0,0,0],
         [0,0,0,3,1,39,0,0,0,0,4],
         [0,2,2,0,4,1,31,0,0,0,2],
         [0,1,0,0,0,0,0,36,0,2,0],
         [0,0,0,0,0,0,1,5,37,5,1],
         [3,0,0,0,0,0,0,0,0,39,0],
         [0,0,0,0,0,0,0,0,0,0,38]]
# Label rows and columns with the class names A..K
df_cm = pd.DataFrame(array, index=list("ABCDEFGHIJK"),
                     columns=list("ABCDEFGHIJK"))
plt.figure(figsize=(10, 7))
sn.heatmap(df_cm, annot=True)  # annot=True writes the count in each cell
plt.show()
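The answer mentions plt.matshow() but only shows code for the seaborn route, so here is a minimal sketch of the matshow variant. The helper names row_normalize and plot_cm_matshow and the toy matrix are made up for illustration and are not part of the original answer. Row normalization is an added assumption: in the question the counts range from thousands down to single digits, so raw counts would let the largest classes dominate the color scale.

```python
import numpy as np
import matplotlib.pyplot as plt


def row_normalize(cm):
    """Divide each row by its total so cells hold per-class fractions."""
    cm = np.asarray(cm, dtype=float)
    totals = cm.sum(axis=1, keepdims=True)
    # Rows with no samples stay all-zero instead of dividing by zero.
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals > 0)


def plot_cm_matshow(cm):
    """Draw a confusion matrix with Axes.matshow and return (fig, ax)."""
    fig, ax = plt.subplots(figsize=(8, 8))
    image = ax.matshow(row_normalize(cm), cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(image, ax=ax, label="fraction of row total")
    ax.set_xlabel("column index")
    ax.set_ylabel("row index")
    return fig, ax


# Toy 3-class matrix, made up for illustration only.
toy_cm = [[50, 2, 0],
          [5, 40, 5],
          [0, 0, 0]]
fig, ax = plot_cm_matshow(toy_cm)
# plt.show()  # uncomment to open a window
```

With the question's data, the call would presumably be plot_cm_matshow(cm), where cm is the array returned by metrics.confusion_matrix.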
Answered by user1644018
@bninopaul's answer is not entirely beginner-friendly. Here is code you can copy and run:
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

array = [[13,1,1,0,2,0],
         [3,9,6,0,1,0],
         [0,0,16,2,0,0],
         [0,0,0,13,0,0],
         [0,0,0,0,15,0],
         [0,0,1,0,0,15]]
df_cm = pd.DataFrame(array, range(6), range(6))
# plt.figure(figsize=(10,7))
sn.set(font_scale=1.4) # for label size
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size
plt.show()
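Both answers above start from a hard-coded array, while the question starts from predictions. As a rough sketch of how the two might be connected, the snippet below builds the matrix with scikit-learn's confusion_matrix and labels it before handing it to seaborn. The y_true and y_pred lists and the class names are invented for illustration; in the question they would correspond to test_labels and pred.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration; in the question these would be
# test_labels and pred.
y_true = ["cat", "cat", "dog", "dog", "bird", "bird", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "bird", "cat", "bird"]

# Fix the class order explicitly so rows and columns are labeled consistently.
class_names = ["bird", "cat", "dog"]
cm = confusion_matrix(y_true, y_pred, labels=class_names)

# Rows are true classes, columns are predicted classes (scikit-learn's convention).
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)

fig, ax = plt.subplots(figsize=(6, 5))
sn.heatmap(df_cm, annot=True, fmt="d", cmap="Blues", ax=ax)
ax.set_xlabel("predicted label")
ax.set_ylabel("true label")
# plt.show()  # uncomment to open a window
```

For something like the 100 classes in the question, the per-cell annotations would likely overlap, so passing annot=False and a larger figsize is probably the more legible choice.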
Answered by Wagner Cipriano
If you want more data in your confusion matrix, including a "totals column" and a "totals row", and percentages (%) in each cell, like the MATLAB default (see image below), along with the heatmap and other options, you should have fun with the module below, shared on GitHub ;)
https://github.com/wcipriano/pretty-print-confusion-matrix
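To make the idea of a totals row, a totals column, and per-cell percentages concrete, here is a minimal sketch using only NumPy and pandas. It is not the linked module's code or API; the function names with_totals and as_percent and the 2x2 matrix are hypothetical, chosen only for illustration.

```python
import numpy as np
import pandas as pd


def with_totals(cm, labels):
    """Return the counts as a DataFrame with an extra 'total' row and column."""
    table = pd.DataFrame(np.asarray(cm), index=labels, columns=labels)
    table["total"] = table.sum(axis=1)     # totals column: sum of each row
    table.loc["total"] = table.sum(axis=0)  # totals row: sum of each column
    return table


def as_percent(cm):
    """Express every cell as a percentage of all samples."""
    cm = np.asarray(cm, dtype=float)
    return 100.0 * cm / cm.sum()


# Toy 2-class matrix, made up for illustration only.
toy_cm = [[13, 1],
          [3, 9]]
table = with_totals(toy_cm, ["A", "B"])
percent = as_percent(toy_cm)
print(table)
print(percent.round(1))
```

The bottom-right cell of the table holds the grand total, because the totals row is computed after the totals column has been added.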
This module can do your task easily and produces the output above with a lot of params to customize your CM: