How does the class_weight parameter in scikit-learn work?
Original question: http://stackoverflow.com/questions/30972029/
Warning: this content is provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must attribute it to the original authors (not me): Stack Overflow
Asked by kilgoretrout
I am having a lot of trouble understanding how the class_weight parameter in scikit-learn's Logistic Regression operates.
The Situation
I want to use logistic regression to do binary classification on a very unbalanced data set. The classes are labelled 0 (negative) and 1 (positive) and the observed data is in a ratio of about 19:1 with the majority of samples having negative outcome.
First Attempt: Manually Preparing Training Data
I split the data I had into disjoint sets for training and testing (about 80/20). Then I randomly sampled the training data by hand to get training data in different proportions than 19:1; from 2:1 -> 16:1.
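A minimal sketch of that manual undersampling (illustrative only; the helper name is made up, and X_train, y_train are assumed to be numpy arrays):

import numpy as np

def undersample_to_ratio(X, y, neg_per_pos, seed=0):
    # keep every positive (label 1) and draw neg_per_pos negatives (label 0) per positive
    rng = np.random.default_rng(seed)
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    n_neg = min(neg.size, neg_per_pos * pos.size)
    keep = np.concatenate([pos, rng.choice(neg, size=n_neg, replace=False)])
    rng.shuffle(keep)
    return X[keep], y[keep]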
I then trained logistic regression on these different training data subsets and plotted recall (= TP/(TP+FN)) as a function of the different training proportions. Of course, the recall was computed on the disjoint TEST samples which had the observed proportions of 19:1. Note, although I trained the different models on different training data, I computed recall for all of them on the same (disjoint) test data.
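Reusing the hypothetical undersample_to_ratio helper above, the whole experiment would look roughly like this (X_test, y_test being the untouched 19:1 test split):

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score

# train on each resampled training set, but always score recall on the same test set
for r in range(2, 17):
    X_sub, y_sub = undersample_to_ratio(X_train, y_train, neg_per_pos=r)
    y_hat = LogisticRegression(max_iter=1000).fit(X_sub, y_sub).predict(X_test)
    print(r, recall_score(y_test, y_hat))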
The results were as expected: the recall was about 60% at 2:1 training proportions and fell off rather fast by the time it got to 16:1. There were several proportions 2:1 -> 6:1 where the recall was decently above 5%.
Second Attempt: Grid Search
Next, I wanted to test different regularization parameters, so I used GridSearchCV and made a grid of several values of the C parameter as well as the class_weight parameter. To translate my n:m proportions of negative:positive training samples into the dictionary language of class_weight, I thought I would just specify several dictionaries as follows:
{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 } #expected 4:1
and I also included None and auto.
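For concreteness, a sketch of such a grid (values illustrative; note that 'auto' only exists in old scikit-learn releases, current versions use 'balanced' instead):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

param_grid = {
    'C': [0.01, 0.1, 1, 10, 100],
    'class_weight': [None, 'balanced', {0: 0.67, 1: 0.33}, {0: 0.75, 1: 0.25}, {0: 0.8, 1: 0.2}],
}
grid = GridSearchCV(LogisticRegression(max_iter=1000), param_grid, scoring='recall', cv=5)
# grid.fit(X_train, y_train); print(grid.best_params_)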
This time the results were totally whacked. All my recalls came out tiny (< 0.05) for every value of class_weight except auto. So I can only assume that my understanding of how to set the class_weight dictionary is wrong. Interestingly, the recall for class_weight='auto' in the grid search was around 59% for all values of C, and I guessed it balances to 1:1?
My Questions
1. How do you properly use class_weight to achieve different balances in training data from what you actually give it? Specifically, what dictionary do I pass to class_weight to use n:m proportions of negative:positive training samples?
2. If you pass various class_weight dictionaries to GridSearchCV, will it rebalance the training-fold data according to the dictionary during cross-validation but use the true, given sample proportions for computing my scoring function on the test fold? This is critical, since any metric is only useful to me if it comes from data in the observed proportions.
3. What does the auto value of class_weight do as far as proportions? I read the documentation and I assume "balances the data inversely proportional to their frequency" just means it makes it 1:1. Is this correct? If not, can someone clarify?
Accepted answer by Andreas Mueller
First off, it might not be good to just go by recall alone. You can simply achieve a recall of 100% by classifying everything as the positive class. I usually suggest using AUC for selecting parameters, and then finding a threshold for the operating point (say a given precision level) that you are interested in.
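A sketch of that workflow (an assumed implementation, not from the answer itself): tune parameters on predicted probabilities via AUC, then pick the lowest threshold that reaches a target precision.

import numpy as np
from sklearn.metrics import precision_recall_curve

def threshold_for_precision(y_true, probs, min_precision=0.8):
    # illustrative helper: precision[i] is the precision when predicting positive
    # for scores >= thresholds[i], so take the first threshold reaching the target
    precision, recall, thresholds = precision_recall_curve(y_true, probs)
    ok = np.flatnonzero(precision[:-1] >= min_precision)
    return thresholds[ok[0]] if ok.size else 0.5  # fall back to 0.5 if never reached

# probs = clf.predict_proba(X_test)[:, 1]
# y_pred = (probs >= threshold_for_precision(y_test, probs)).astype(int)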
For how class_weight works: it penalizes mistakes in samples of class[i] with class_weight[i] instead of 1. So a higher class weight means you want to put more emphasis on that class. From what you say, it seems class 0 is 19 times more frequent than class 1, so you should increase the class_weight of class 1 relative to class 0, say {0: .1, 1: .9}. If the class_weight doesn't sum to 1, it will basically change the regularization parameter.
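The "instead of 1" can be checked directly: a class_weight dict behaves like a per-sample weight on the loss, so fitting with class_weight should match fitting with the equivalent sample_weight (a quick sanity check, illustrative, assuming a recent scikit-learn):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, weights=[0.9], random_state=0)
clf_cw = LogisticRegression(class_weight={0: 1, 1: 5}).fit(X, y)
clf_sw = LogisticRegression().fit(X, y, sample_weight=np.where(y == 1, 5.0, 1.0))
print(np.allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-4))  # True, up to solver tolerance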
For how class_weight="auto" works, you can have a look at this discussion. In the dev version you can use class_weight="balanced", which is easier to understand: it basically means replicating the smaller class until you have as many samples as in the larger one, but in an implicit way.
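In released scikit-learn, "balanced" computes the weights as n_samples / (n_classes * np.bincount(y)), so with the question's 19:1 ratio the minority class simply gets 19 times the weight of the majority class:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 95 + [1] * 5)  # 19:1 negative:positive, as in the question
print(compute_class_weight('balanced', classes=np.array([0, 1]), y=y))  # [0.526..., 10.0]
print(len(y) / (2 * np.bincount(y)))  # the same formula written out by hand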
Answered by citynorman
The first answer is good for understanding how it works. But I wanted to understand how I should be using it in practice.
SUMMARY
- For moderately imbalanced data WITHOUT noise, there is not much of a difference in applying class weights.
- For moderately imbalanced data WITH noise, and for strongly imbalanced data, it is better to apply class weights.
- The param class_weight="balanced" works decently when you don't want to optimize manually.
- With class_weight="balanced" you capture more true events (higher TRUE recall), but you are also more likely to get false alerts (lower TRUE precision):
  - as a result, the total % TRUE might be higher than actual because of all the false positives
  - AUC might misguide you here if the false alarms are an issue
- No need to change the decision threshold to the imbalance %; even for strong imbalance it's ok to keep 0.5 (or somewhere around that, depending on what you need).
CODE
# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd
# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2
LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?
roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same
# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06
LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last
roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last
print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few predicted TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%
print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE