Python 如何绘制混淆矩阵?

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/35572000/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 16:37:03  来源:igfitidea点击:

How can I plot a confusion matrix?

pythonmatplotlibmatrixscikit-learntext-classification

提问by minks

I am using scikit-learn for classification of text documents(22000) to 100 classes. I use scikit-learn's confusion matrix method for computing the confusion matrix.

我正在使用 scikit-learn 将文本文档(22000)分类为 100 个类。我使用 scikit-learn 的混淆矩阵方法来计算混淆矩阵。

model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')

This is how my confusion matrix looks like:

这是我的混淆矩阵的样子:

[[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ..., 
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]

However, I do not receive a clear or legible plot. Is there a better way to do this?

但是,我没有收到清晰易读的情节。有一个更好的方法吗?

回答by bninopaul

enter image description here

在此处输入图片说明

you can use plt.matshow()instead of plt.imshow()or you can use seaborn module's heatmap(see documentation) to plot the confusion matrix

您可以使用plt.matshow()代替plt.imshow()或者您可以使用 seaborn 模块heatmap请参阅文档)来绘制混淆矩阵

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3], 
        [3,31,0,0,0,0,0,0,0,0,0], 
        [0,4,41,0,0,0,0,0,0,0,1], 
        [0,1,0,30,0,6,0,0,0,0,1], 
        [0,0,0,0,38,10,0,0,0,0,0], 
        [0,0,0,3,1,39,0,0,0,0,4], 
        [0,2,2,0,4,1,31,0,0,0,2],
        [0,1,0,0,0,0,0,36,0,2,0], 
        [0,0,0,0,0,0,1,5,37,5,1], 
        [3,0,0,0,0,0,0,0,0,39,0], 
        [0,0,0,0,0,0,0,0,0,0,38]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
                  columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

回答by user1644018

@bninopaul 's answer is not completely for beginners

@bninopaul 的答案并不完全适合初学者

here is the code you can "copy and run"

这是您可以“复制并运行”的代码

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

array = [[13,1,1,0,2,0],
         [3,9,6,0,1,0],
         [0,0,16,2,0,0],
         [0,0,0,13,0,0],
         [0,0,0,0,15,0],
         [0,0,1,0,0,15]]

df_cm = pd.DataFrame(array, range(6), range(6))
# plt.figure(figsize=(10,7))
sn.set(font_scale=1.4) # for label size
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size

plt.show()

result

结果

回答by Wagner Cipriano

IF you want more datain you confusion matrix, including "totals column" and "totals line", and percents(%) in each cell, like matlab default(see image below)

如果您想在混淆矩阵中添加更多数据,包括“总计列”和“总计行”,以及每个单元格中的百分比(%),例如 matlab 默认值(见下图)

enter image description here

在此处输入图片说明

including the Heatmap and other options...

包括热图和其他选项...

You should have fun with the module above, shared in the github ; )

你应该对上面的模块很感兴趣,在 github 中共享;)

https://github.com/wcipriano/pretty-print-confusion-matrix

https://github.com/wcipriano/pretty-print-confusion-matrix



This module can do your task easily and produces the output above with a lot of params to customize your CM: enter image description here

该模块可以轻松完成您的任务,并使用大量参数生成上面的输出以自定义您的 CM: 在此处输入图片说明