Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must follow the same CC BY-SA license, cite the original URL and authors, and attribute it to the original authors (not me): StackOverflow original: http://stackoverflow.com/questions/45361559/

Date: 2020-08-19 16:57:40  Source: igfitidea

Feature Importance Chart in neural network using Keras in Python

Tags: python, neural-network, keras

Question by andre

I am using Python 3.6, Anaconda (64-bit) and Spyder 3.1.2. I have already set up a neural network model using Keras 2.0.6 for a regression problem (one response, 10 variables). How can I generate a feature importance chart like this one:

[Image: feature importance chart]

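from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasRegressor

def base_model():
    # one hidden layer of 200 ReLU units over the 10 input variables
    model = Sequential()
    model.add(Dense(200, input_dim=10, kernel_initializer='normal', activation='relu'))
    # single linear output unit for the regression response
    model.add(Dense(1, kernel_initializer='normal'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

clf = KerasRegressor(build_fn=base_model, epochs=100, batch_size=5, verbose=0)
clf.fit(X_train, Y_train)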

Answer by Justin Hallas

I was recently looking for the answer to this question, found something useful for what I was doing, and thought it would be helpful to share. I ended up using the permutation importance module from the eli5 package. It works most easily with a scikit-learn model. Luckily, Keras provides scikit-learn wrappers (KerasClassifier and KerasRegressor) for Sequential models. As the code below shows, using it is very straightforward.

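from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
import eli5
from eli5.sklearn import PermutationImportance

def base_model():
    model = Sequential()
    ...  # add layers and compile, as in the question's base_model
    return model

X = ...  # features; a pandas DataFrame, so that X.columns is available below
y = ...  # target

# sk_params: fit/build arguments such as epochs and batch_size
my_model = KerasRegressor(build_fn=base_model, **sk_params)
my_model.fit(X, y)

# Shuffle each column in turn and measure how much the model's score drops
perm = PermutationImportance(my_model, random_state=1).fit(X, y)
# show_weights renders HTML in a Jupyter notebook; outside a notebook,
# eli5.format_as_text(eli5.explain_weights(perm, feature_names=...)) returns plain text
eli5.show_weights(perm, feature_names=X.columns.tolist())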
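To show what eli5's `PermutationImportance` is measuring, here is a minimal NumPy sketch of the same idea, written from scratch rather than taken from eli5: shuffle one feature column at a time and record how much the error grows. It assumes a single-output regression model whose predictions can be flattened to 1D, and it uses the increase in mean squared error as the score (eli5 uses the estimator's `score` method unless `scoring` is given). The `predict` function and the data are hypothetical stand-ins for a trained Keras model's `predict` and your own `X`, `y`.

```python
import numpy as np


def permutation_importance(predict, X, y, n_repeats=5, seed=0):
    """Mean increase in MSE when each column of X is shuffled.

    predict: callable mapping an (n_samples, n_features) array to predictions,
             e.g. a fitted Keras model's `predict` (assumed single-output).
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean((predict(X).ravel() - y) ** 2)
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        increases = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffle column j to break its link with y; other columns stay intact
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            mse = np.mean((predict(X_perm).ravel() - y) ** 2)
            increases.append(mse - baseline)
        importances[j] = np.mean(increases)
    return importances


# Hypothetical stand-in for a trained regressor:
# y depends strongly on x0, weakly on x1, and not at all on x2
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
y = 3.0 * X[:, 0] + 0.5 * X[:, 1]
predict = lambda A: 3.0 * A[:, 0] + 0.5 * A[:, 1]

imp = permutation_importance(predict, X, y)
for name, score in sorted(zip(["x0", "x1", "x2"], imp), key=lambda t: -t[1]):
    print(f"{name}: {score:.3f}")
```

These per-feature scores are what a chart like the one in the question plots; for example, they can be passed to matplotlib's `barh` together with the feature names. Averaging over `n_repeats` shuffles reduces the noise from any single permutation.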

Answer by paolof89

At the moment, Keras doesn't provide any built-in functionality for extracting feature importance.

You can check this previous question: Keras: Any way to get variable importance?

or the related Google Groups thread: Feature importance

Spoiler: in the Google Groups thread, someone announced an open-source project to solve this issue.

Answer by jarrettyeo

This is a relatively old post with relatively old answers, so I would like to offer another suggestion: use SHAP to determine feature importance for your Keras models. SHAP supports both 2D and 3D arrays, whereas eli5 currently supports only 2D arrays (so if your model uses layers that require 3D input, such as LSTM or GRU, eli5 will not work).
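If you want permutation importance anyway for a model with 3D input of shape (samples, timesteps, features), as used by LSTM or GRU layers, one option is to do the shuffling yourself instead of going through eli5. The NumPy sketch below is not part of eli5 or SHAP; it shuffles one feature across samples (swapping whole sequences for that feature, so the time order inside each sequence is kept) and scores the increase in mean squared error. `predict_seq` is a hypothetical stand-in for a trained sequence model's `predict`.

```python
import numpy as np


def permutation_importance_3d(predict, X, y, n_repeats=5, seed=0):
    """Permutation importance for input shaped (samples, timesteps, features).

    Feature j is shuffled across samples, so each sequence keeps its time
    order. Score: mean increase in MSE over the unshuffled baseline.
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean((predict(X).ravel() - y) ** 2)
    importances = np.zeros(X.shape[2])
    for j in range(X.shape[2]):
        increases = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Give every sample the feature-j sequence of a randomly chosen other sample
            X_perm[:, :, j] = X[rng.permutation(X.shape[0]), :, j]
            mse = np.mean((predict(X_perm).ravel() - y) ** 2)
            increases.append(mse - baseline)
        importances[j] = np.mean(increases)
    return importances


# Hypothetical stand-in for a trained LSTM/GRU regressor:
# output = mean over time of feature 0; feature 1 is ignored
rng = np.random.default_rng(2)
X_seq = rng.normal(size=(300, 10, 2))
y_seq = X_seq[:, :, 0].mean(axis=1)
predict_seq = lambda A: A[:, :, 0].mean(axis=1)

imp_seq = permutation_importance_3d(predict_seq, X_seq, y_seq)
print(imp_seq)
```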

Here is the link to an example of how SHAP can plot the feature importance for your Keras models; in case the link ever breaks, some sample code and plots are also provided below (taken from said link):

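import shap

# load your data here, e.g. X (a pandas DataFrame) and y
# create and fit your Keras model here, e.g. `model`

# load JS visualization code to notebook
shap.initjs()

# TreeExplainer (used in the linked example) only supports tree ensembles such as
# XGBoost, LightGBM and CatBoost, so it cannot explain a Keras network.
# KernelExplainer is model-agnostic: it needs a prediction function and a small
# background sample. Flattening the output gives one SHAP value per feature
# for a single-output regressor.
predict_fn = lambda data: model.predict(data).flatten()
explainer = shap.KernelExplainer(predict_fn, shap.sample(X, 100))
# KernelExplainer is slow, so explaining a subset of rows is common
X_explain = X.iloc[:100, :]
shap_values = explainer.shap_values(X_explain)

# visualize the first prediction's explanation (use matplotlib=True to avoid Javascript)
shap.force_plot(explainer.expected_value, shap_values[0, :], X_explain.iloc[0, :])

shap.summary_plot(shap_values, X_explain, plot_type="bar")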
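`shap.summary_plot(..., plot_type="bar")` ranks features by the mean absolute SHAP value over the explained rows, and `force_plot` shows how one row's SHAP values add up from `explainer.expected_value` to that row's prediction. The NumPy sketch below illustrates both properties without calling shap. It assumes a hypothetical linear model with independent features, for which the exact (interventional) SHAP value of feature j on row i has the closed form w_j * (x_ij - mean_j); it is an illustration, not how KernelExplainer computes its estimates.

```python
import numpy as np

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 3))

# Hypothetical stand-in model: f(x) = w . x + b
w, b = np.array([2.0, -1.0, 0.1]), 0.5
f = lambda A: A @ w + b

# Base value: the average prediction over the data
expected_value = f(X_demo).mean()
# Exact SHAP values for a linear model with independent features
shap_values = w * (X_demo - X_demo.mean(axis=0))

# Local accuracy: base value + a row's contributions = that row's prediction
assert np.allclose(expected_value + shap_values.sum(axis=1), f(X_demo))

# Global importance, as drawn by summary_plot(..., plot_type="bar")
global_importance = np.abs(shap_values).mean(axis=0)
print(global_importance)
```

The bar heights therefore reflect how far each feature typically pushes predictions away from the base value, regardless of direction.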

[Image: example SHAP plot]