Python 从 Keras 多类模型中获取混淆矩阵

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/50920908/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 19:38:38  来源:igfitidea点击:

Get Confusion Matrix From a Keras Multiclass Model

pythonkeras

提问by ATMA

I am building a multiclass model with Keras.

我正在用 Keras 构建一个多类模型。

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[checkpoint], validation_data=(X_test, y_test))  # starts training

Here is how my test data looks like (it's text data).

这是我的测试数据的样子(它是文本数据)。

X_test
Out[25]: 
array([[621, 139, 549, ...,   0,   0,   0],
       [621, 139, 543, ...,   0,   0,   0]])

y_test
Out[26]: 
array([[0, 0, 1],
       [0, 1, 0]])

After generating predictions...

生成预测后...

predictions = model.predict(X_test)
predictions
Out[27]: 
array([[ 0.29071924,  0.2483743 ,  0.46090645],
       [ 0.29566404,  0.45295066,  0.25138539]], dtype=float32)

I did the following to get the confusion matrix.

我做了以下工作来获得混淆矩阵。

y_pred = (predictions > 0.5)

confusion_matrix(y_test, y_pred)
Traceback (most recent call last):

  File "<ipython-input-38-430e012b2078>", line 1, in <module>
    confusion_matrix(y_test, y_pred)

  File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 252, in confusion_matrix
    raise ValueError("%s is not supported" % y_type)

ValueError: multilabel-indicator is not supported

However, I am getting the above error.

但是,我收到上述错误。

How can I get a confusion matrix when doing a multiclass neural network in Keras?

在 Keras 中进行多类神经网络时如何获得混淆矩阵?

回答by Neabfi

Your input to confusion_matrixmust be an array of int not one hot encodings.

您的输入confusion_matrix必须是一个 int 数组,而不是一个热编码。

matrix = metrics.confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))