Stata:交叉验证之LOOCV方法-looclass命令详解

发布时间:2021-07-05 阅读 114

Stata连享会   主页 || 视频 || 推文 || 知乎 || Bilibili 站

温馨提示: 定期 清理浏览器缓存,可以获得最佳浏览体验。

New! lianxh 命令发布了:
随时搜索推文、Stata 资源。安装:
. ssc install lianxh
详情参见帮助文件 (有惊喜):
. help lianxh
连享会新命令:cnssc, ihelp, rdbalance, gitee, installpkg

课程详情 https://gitee.com/lianxh/Course

课程主页 https://gitee.com/lianxh/Course

⛳ Stata 系列推文:

PDF下载 - 推文合集

作者:王乔 (中南财经政法大学)
邮箱zuelerqiao@foxmail.com


目录


1. 交叉验证

1.1 简介

在实际的训练中,训练结果对于训练集的拟合程度通常比较好 (初始条件敏感),但是对于训练集之外的数据的拟合程度通常就不那么令人满意了。因此,我们通常不会把所有数据都拿来训练,而是分出一部分来 (这一部分不参加训练) 对训练集生成的参数进行测试,相对客观的判断这些参数对训练集之外的数据的符合程度。这种思想称为交叉验证。

交叉验证由于其简洁性和普遍性被认为是一种行之有效的办法,尤其是在可用的数据较少的情况下,通过对数据的有效重复利用,交叉验证充分显示了其在模型选择方面的诸多优点。

交叉验证的具体做法是:将数据集平均分成 n 份,选其中的 n-1 份作为训练集,剩余的 1 份作为验证集,用以上 n 种情况的训练集训练得到模型的参数。如果不做交叉验证,只需训练一次。但是当做交叉验证时,就等于训练了 n 次,训练时间变成了 n 倍。因此当数据集足够大的时候,交叉验证并不常用。

1.2 具体方法

  • 保留交叉验证 (hand-out cross validation):随机将训练样本集分成训练集和交叉验证集,比如分别占 70%、30%,然后使用模型在训练集上学习得到假设,最后使用交叉验证集对假设进行验证,看预测的是否准确,选择均方误差 (MSE) 小的模型;
  • k折交叉验证 (K-fold cross validation):将初始采样分割成 k 个子样本,一个单独的子样本被保留作为验证模型的数据,其他 k-1 个样本用来训练。交叉验证重复 k 次,每个子样本验证一次,平均 k 次的结果或者使用其他结合方式,最终得到一个单一估测。这个方法的优势在于,同时重复运用随机产生的子样本进行训练和验证,每次的结果验证一次; 留一交叉验证 (Leave-one-out Cross-validation):正常训练都会划分训练集和验证集,训练集用来训练模型,而验证集用来评估模型的泛化能力。留一交叉验证是一个极端的例子,如果数据集 D 的大小为 N,那么用 N - 1 条数据进行训练,用剩下的一条数据作为验证。不过,用一条数据作为验证的坏处是,根据训练得到的均方误差与验证结果得到的均方误差相差很大。所以,在留一交叉验证里,每次从数据集 D 中取一组作为验证集,直到所有样本都做过验证集,共计算 N 次,最后对均方误差求平均。

1.3 留一交叉验证详解

LOOCV 方法即留一交叉验证法,一般是将数据集分为训练集和测试集,然后只用一个数据作为测试集,其他的数据都作为训练集,并将此步骤重复 n 次 (n 为数据集的数据数量)。

如上图所示,假设我们现在有 n 个数据组成的数据集,那么 LOOCV 的方法就是每次取出一个数据作为测试集的唯一元素,而其他 n - 1 个数据都作为训练集用于训练模型和调整参数。我们最终训练了 n 个模型,每次都能得到一个均方误差 (MSE)。而 test MSE 则就是将这 n 个均方误差 (MSE) 取平均:

LOOCV 方法有很多优点,首先,它不受测试集合训练集划分方法的影响,因为每一个数据都单独的做过测试集。同时,其用了 n-1 个数据训练模型,也几乎用到了所有的数据,保证了模型的 bias 更小。不过 LOOCV 的缺点也很明显,那就是计算量过于大,是只训练一个模型耗时的 n-1 倍。

2. looclass 命令

2.1 命令安装

looclass 是对具有二值结果的回归模型执行留一交叉验证,然后生成分类度量来帮助确定预测模型的错误率 (或者精确性)。它是一个简单的 n 折交叉验证,其中 n 是数据集中的观察数。依次忽略每个观测结果,对所有剩余的观测结果进行模型估计。然后,计算一个坚持观察的预测值,其准确性取决于对观测结果预测的成功或失败。

. cnssc install looclass, replace

2.2 命令语法

 looclass depvar indepvars [if] [in] [weight] , model(str) [cutoff(#)
                 save figure model_options]
  • depvar:被解释变量;
  • indepvar:解释变量;
  • model(string):指定模型,是必需的。可用的模型有 logitprobitrforestboostsvmachines (必须安装 rforestboostsvmachines 命令才能使用模型);
  • cutoff(#):指定一个值,用于确定一个观察结果是否在分类表中具有预期的积极结果。如果一个观测结果的预测概率为大于 # ,它就被归为积极观测值,默认值是 0.5;
  • save:生成两个变量,fulltest,分别包含对完整数据和测试数据的估计预测;
  • figure:生成一个显示完整数据和测试数据的 ROC 曲线的图表。

3. Stata 实操

我们使用网络数据库中的 Hosmer & Lemeshow 数据,来探究 looclass 命令在 Stata 中的应用。

. *导入数据
. webuse lbw, clear
(Hosmer & Lemeshow data)

由于 looclass 是对具有二值结果的回归模型进行留一交叉验证,因此我们需要查看被解释变量 low 的数据类型。

. tab low

birthweight |
     <2500g |      Freq.     Percent        Cum.
------------+-----------------------------------
          0 |        130       68.78       68.78
          1 |         59       31.22      100.00
------------+-----------------------------------
      Total |        189      100.00

我们选用 logit 模型进行留一交叉验证回归,将数据分为 5 组,保存拟合值,并画出 ROC 曲线。

. looclass low age lwt i.race smoke ptl ht ui, model(logit) fig

首先,呈现 ROC 曲线,如下图:

如果 ROC 越偏向左上角,这说明模型预测效果越好,同时图片下方还计算了 AUC 的值 (也就是横坐标从 0 到 1 曲线下方的面积),全部样本的 AUC 值为 0.7462,测试样本的 AUC 值为 0.6838。

同时还会显示以下表格:

Iterating across (189) observations
----+--- 1 ---+--- 2 ---+--- 3 ---+--- 4 ---+--- 5 
..................................................    50
..................................................   100
..................................................   150
.......................................

Classification Table for Full Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |        21            12  |         33
     -     |        38           118  |        156
-----------+--------------------------+-----------
   Total   |        59           130  |        189

上表是用全部样本 (full data) 作为训练集,然后用全部样本作为测试集,最后得到的统计结果。其中,坐标 + 表示拟合值等于 1 的数目,- 表示拟合值等于 0 的数目;坐标 D 代表真实等于 1 的数目,~D 表示真实等于 0 的数目。

Classification Table for Test Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |        18            18  |         36
     -     |        41           112  |        153
-----------+--------------------------+-----------
   Total   |        59           130  |        189

由于是使用留一法进行交叉检验,所以上表是只使用原本样本中的一项来当做验证资料,而剩余的则留下来当做训练资料。这个步骤一直持续到每个样本都被当做一次验证资料为止。其中,坐标 + 表示拟合值等于 1 的数目,- 表示拟合值等于 0 的数目;坐标 D 代表真实等于 1 的数目,~D 表示真实等于 0 的数目。

Classified + if predicted Pr(D) >= .5
True D defined as  != 0
                                            Full         Test
----------------------------------------------------------------
Sensitivity                     Pr( +| D)   35.59%       30.51%
Specificity                     Pr( -|~D)   90.77%       86.15%
Positive predictive value       Pr( D| +)   63.64%       50.00%
Negative predictive value       Pr(~D| -)   75.64%       73.20%
----------------------------------------------------------------
False + rate for true ~D        Pr( +|~D)    9.23%       13.85%
False - rate for true D         Pr( -| D)   64.41%       69.49%
False + rate for classified +   Pr(~D| +)   36.36%       50.00%
False - rate for classified -   Pr( D| -)   24.36%       26.80%
----------------------------------------------------------------
Correctly classified                        73.54%       68.78%
----------------------------------------------------------------
ROC area                                    0.7462       0.6838
----------------------------------------------------------------
p-value for Full vs Test ROC areas                       0.0000
----------------------------------------------------------------

对上表的解释说明如下:

  • 第一列为变量名;
  • 第二列为变量的含义:比如 Pr( +| D) 代表在真实值为 1 的集合中,预测也为 1 的样本数量占真实值为 1 集合的比率,Pr( -|~D) 代表真实值为 0 的集合中,预测为 0 的样本占真实值为 0 集合中样本数目的比率;
  • 第三列为用全部数据训练模型得到的变量值;
  • 第四列为用交叉验证法训练模型得到的变量值;
  • Correctly classified 表示预测的准确率,也就是预测对的样本数占所有样本的数量;
  • ROC area 表示 AUC 值,也就是 ROC 曲线下方的面积。

对各种评价指标的计算及定义具体参考「机器学习:准确率 (Precision) 、召回率 (Recall) 、F 值 (F-Measure) 、ROC 曲线、PR 曲线」

接下来,我们对留一交叉验证加一个限制条件。第二个示例与上面的示例基本相同,但是该回归增加了条件,将样本限制为 age 小于 30 ,并对完整数据和测试数据的预测值进行了保存。

. looclass low age lwt i.race smoke ptl ht ui if age<30, model(logit) fig save 

增加了限制条件后,相应的留一交叉结果如下:

 Classification Table for Full Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |        21            10  |         31
     -     |        34            97  |        131
-----------+--------------------------+-----------
   Total   |        55           107  |        162
Classification Table for Test Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |        21            17  |         38
     -     |        34            90  |        124
-----------+--------------------------+-----------
   Total   |        55           107  |        162
Classified + if predicted Pr(D) >= .5
True D defined as  != 0
                                            Full         Test
----------------------------------------------------------------
Sensitivity                     Pr( +| D)   38.18%       38.18%
Specificity                     Pr( -|~D)   90.65%       84.11%
Positive predictive value       Pr( D| +)   67.74%       55.26%
Negative predictive value       Pr(~D| -)   74.05%       72.58%
----------------------------------------------------------------
False + rate for true ~D        Pr( +|~D)    9.35%       15.89%
False - rate for true D         Pr( -| D)   61.82%       61.82%
False + rate for classified +   Pr(~D| +)   32.26%       44.74%
False - rate for classified -   Pr( D| -)   25.95%       27.42%
----------------------------------------------------------------
Correctly classified                        72.84%       68.52%
----------------------------------------------------------------
ROC area                                    0.7316       0.6605
----------------------------------------------------------------
p-value for Full vs Test ROC areas                       0.0000
----------------------------------------------------------------

在使用 looclass 进行的留一交叉验证法时,我们除了可以增加限制条件之外,还可以使用 model 选项对模型进行指定。上面两条命令都指定了 logit 模型,下面我们将指定 svmachines 模型,并对完整数据和测试数据的预测值进行保存。

. *安装命令
. cnssc install svmachines, replace
. looclass low age lwt i.race smoke ptl ht ui, model(svmachines) fig save 

选用 svmachines 模型进行留一交叉验证回归的结果如下:

Classification Table for Full Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |         0             0  |          0
     -     |        59           130  |        189
-----------+--------------------------+-----------
   Total   |        59           130  |        189
Classification Table for Test Data:

              -------- True --------
Classified |         D            ~D  |      Total
-----------+--------------------------+-----------
     +     |         0             0  |          0
     -     |        59           130  |        189
-----------+--------------------------+-----------
   Total   |        59           130  |        189
Classified + if predicted Pr(D) >= .5
True D defined as  != 0
                                            Full         Test
----------------------------------------------------------------
Sensitivity                     Pr( +| D)    0.00%        0.00%
Specificity                     Pr( -|~D)  100.00%      100.00%
Positive predictive value       Pr( D| +)       .%           .%
Negative predictive value       Pr(~D| -)   68.78%       68.78%
----------------------------------------------------------------
False + rate for true ~D        Pr( +|~D)    0.00%        0.00%
False - rate for true D         Pr( -| D)  100.00%      100.00%
False + rate for classified +   Pr(~D| +)       .%           .%
False - rate for classified -   Pr( D| -)   31.22%       31.22%
----------------------------------------------------------------
Correctly classified                        68.78%       68.78%
----------------------------------------------------------------
ROC area                                    0.0119       0.3288
----------------------------------------------------------------
p-value for Full vs Test ROC areas                       0.0000
----------------------------------------------------------------

当然,除了可以使用 logit 模型和 svmachines 模型外,我们还可以使用 model (string) 选项来设置其他的模型,例如 logitprobitrforestboostsvmachines

进行留一交叉验证后,还可使用 cutpt 来估计完整数据和测试数据的 “最佳” 切割点。

. *安装命令
. cnssc install cutpt, replace
. cutpt low full, youden

对于全样本 (full data) 进行 cutpt 的结果如下,使用的切割方法是 Youden,经验上最佳的切割点是 0.3018164,切点处 ROC 曲线下的面积为 0.69。

Empirical cutpoint estimation
Method:                                Youden
Reference variable:                    low (0=neg, 1=pos)
Classification variable:               full
Empirical optimal cutpoint:            .30158164
Youden index (J):                      0.372
SE(J):                                 0.0751
Sensitivity at cutpoint:               0.75
Specificity at cutpoint:               0.63
Area under ROC curve at cutpoint:      0.69

下面是基于测试样本 (test data) 的 cutpt

. cutpt low test, youden

对于 test 样本进行 cutpt 的结果如下,使用的切割方法是 Youden,经验上最佳的切割点是 0.41182332,切点处 ROC 曲线下的面积为 0.63。

Empirical cutpoint estimation
Method:                                Youden
Reference variable:                    low (0=neg, 1=pos)
Classification variable:               test
Empirical optimal cutpoint:            .41182332
Youden index (J):                      0.267
SE(J):                                 0.0785
Sensitivity at cutpoint:               0.49
Specificity at cutpoint:               0.78
Area under ROC curve at cutpoint:      0.63

4. 参考文献

  • Linden A. LOOCLASS: stata module for generating classification statistics of leave-One-Out cross-validation for binary outcomes[J]. 2020. -Link-
  • LOOCV - Leave-One-Out-Cross-Validation 留一交叉验证 -Link-
  • 机器学习:准确率(Precision)、召回率(Recall)、F值(F-Measure)、ROC曲线、PR曲线 -Link-
  • Cross-Validation(交叉验证)详解 -Link-

5. 相关推文

Note:产生如下推文列表的 Stata 命令为:
lianxh 交叉验证 刀切法 bootstrap logit
安装最新版 lianxh 命令:
ssc install lianxh, replace