Stata:随机森林算法简介与实现

发布时间:2021-01-28 阅读 8769

Stata连享会   主页 || 视频 || 推文 || 知乎

温馨提示: 定期 清理浏览器缓存,可以获得最佳浏览体验。

New! lianxh 命令发布了:
随时搜索连享会推文、Stata 资源,安装命令如下:
. ssc install lianxh
使用详情参见帮助文件 (有惊喜):
. help lianxh

课程详情 https://gitee.com/arlionn/Course   |   lianxh.cn

课程主页 https://gitee.com/arlionn/Course

作者: 李胜胜 (安徽大学)
邮箱: lishengsheng2016@126.com

编者按:本文摘译自下文,特此致谢!

Source: Schonlau M, Zou R Y. The random forest algorithm for statistical learning[J]. The Stata Journal, 2020, 20(1): 3-29. -PDF-


目录


1. 背景介绍

「随机森林」是利用统计或机器学习算法进行预测的方法。随着技术的发展,该方法开始被广泛的应用到社会科学中,并取得了较好预测效果。之所以如此,是因为随机森林考虑到了数据中的非线性关系。

与此同时,随机森林集成的学习算法也非常适合大型数据集。并且,当自变量的数量大于观察值的数量时,线性回归和逻辑回归将无法运行,但随机森林可以有效地回归。

实际上,随机森林是表现最佳的统计学习算法之一。为此,本文将简要介绍随机森林算法、实现命令 rforest、以及案例。

2. 随机森林算法

首先,我们先了解一下基于树的模型,它们是随机森林算法的基础。树的模型是基于特定标准将给定数据集划分为两组,采用递归直到满足预定的停止条件为止。

图 1 说明了一个具有轴对齐边界的二维输入空间的递归分区,每次在平行于一个轴的方向上划分输入空间时,都会进行递归分区。在这里,第一次分裂发生在 x2a2 上。然后,再次划分两个子空间,左分支在 x1a4 上拆分,右分支首先在 x1a1 上拆分,其子分支之一在 x2>a3 上进行拆分。图 2 是图 1 中分区的子空间图形表示。

图 1:二维子空间的递归二进制分区
图 1:二维子空间的递归二进制分区
图 2:决策树的图形表示
图 2:决策树的图形表示

根据划分和停止标准的设置方式,可以分为类任务 (分类结果,例如逻辑回归) 和回归任务 (连续结果) 来设计决策树。对于分类问题和回归问题,用于分割内部节点的预测变量子集取决于预先确定的分割准则,该准则被定义为优化问题。分类问题中常见的分裂准则是熵,在决策树的每个内部节点处,熵的计算公式如下:

其中,c 是唯一类的数量,pi 是每个给定类的先验概率。最大化此值可以在决策树的每个拆分中获取最多的信息。对于回归问题,常用的分割标准是每个内部节点的均方误差。

决策树的一个缺点是容易过度拟合,这意味着模型过于严格地遵循测试数据集的特征,而在新数据集 (即测试数据) 上的性能较差。决策树过度拟合会导致较低的预测精度,也称为泛化精度。

提高泛化精度的一种方法是只考虑观测值的一个子集,并构建很多独立的树。该思想首先由 Ho (1995) 提出,随后随机子空间方法的思想被扩展,并由 Breiman (2001) 正式命名为随机森林。随机森林模型是一种基于集成树的学习算法,该算法对多个单独的树进行平均预测。其中,单独的树是基于引导样本,而不是原始样本构建。这种思想被称为引导聚合或简单的装袋,并可以减少过度拟合。

虽然各个决策树很容易解释,但是由于很多决策树聚合在一起,会导致可解释性在随机森林中丢失。但是作为交换,随机森林通常在预测任务上的表现要好得多。与决策树相比,随机森林算法可以更准确地估计误差率,同时数学家已经证明了误差率总是随着树的数量增加而收敛 (Breiman,2001)。

在训练过程中,随机森林的误差由袋外误差 (OOB) 近似。每棵树都建立在不同的引导程序样本上。其中,每个引导程序样本都将约三分之一的观测值随机排除在外。对于给定的树,这些遗漏的观测值被称为 OOB 样本。因此,查找可能产生低 OOB 误差的参数通常是模型选择和参数调整需要考虑的关键性因素。同时,在随机森林算法中,预测变量的子集 m 的大小对于控制树的最终深度至关重要,该参数在模型选择期间就需要调整。

3. rforest 命令

rforest 命令安装:

*命令安装
ssc install rforest, replace 
*下载地址
*https://fmwww.bc.edu/repec/bocode/r/

rforest 命令语法结构:

rforest depvar indepvars [if] [in] , [ options ]

主要选项介绍:

  • type(string):决策树类型,必须是 class (分类) 或 reg (回归) 之一;

  • iterations(int):设置迭代次数 (树),默认迭代次数为 100;

  • numvars(int):设置随机调查的变量数量,默认为解释变量数开方;

  • depth(int):设置随机森林的最大深度,默认值为 depth(0)

  • lsize(int):设置每片叶子的最小观察数,默认值为 lsize(1)

  • variance(real):设置节点上方差的最小比例,以便在回归树中执行拆分,默认是 variance(1e^(-3)),仅适用于回归;

  • seed(int):设置种子值,默认为 seed(1)

  • numdecimalplaces(int):设置计算精度,默认为numdecimalplaces(5)

4. 信用卡违约案例

Yeh 和 Lien (2009)、Dheeru 和 Karra (2017) 研究了信用卡客户违约概率的预测准确性。该数据集中共有 30000 个观察值,1 个响应变量 (被解释变量),22个解释变量。其中,响应变量是一个二进制变量,0 为 “无违约”,1 为 “违约”。在 22 个解释变量中,10 个是分类变量,包含性别、教育程度、婚姻状况、延迟付款等信息。其余 12 个是连续的解释变量,包含每月账单金额、超 6 个月的付款金额信息等。

在该例中,主要研究影响信用卡违约预测准确性的因素,并将随机森林和 logistic 回归得到的预测精度进行对比。

4.1 模型训练和参数调整

首先,我们需要对数据进行随机排列。其作用在于确保训练数据是随机的。为了获得可重复的结果,需要设置一个种子值。

然后,将数据集分为两个子集,50% 的数据用于训练,50% 的数据用于测试。在小型数据集中,按照 50-50 分割会减少很多训练数据,对于大型的数据集,将其分割为 50-50 并没有什么问题。

最后,由于婚姻状况是使用 0、1、2 和 3 来编码的无序分类信息,因此需要使用 tabulate marriage 命令生成新的变量 marriage_enum

接下来,调整超参数,以找到精度最高的模型。具体来看,通过调整迭代数和变量数,计算出 OOB 预测精度。迭代次数从 10 开始,每次增加 5,直到达到 500。使用 OOB 误差 (针对未包含在子树构造中的训练数据子集进行测试) 和验证误差 (针对测试数据进行测试),以确定最佳的模型。

为了说明 OOB 误差和验证误差如何随着迭代次数的增加而有相似的趋势,需要不断迭代调用随机森林函数,可以通过将这些值与迭代次数相对应来可视化 OOB 误差和验证误差的趋势,如图 3 所示。

*数据下载地址:https://gitee.com/arlionn/data/tree/master/data01/default_news_do

import excel using "default of credit card clients.xls", clear firstrow
rename *, lower  //和原论文对应,将变量字母变小写
label define marriage_label 0 missing 1 married 2 single 3 other
label values marriage marriage_label
tabulate marriage, generate(marriage_enum)

set seed 201807
generate u=uniform()
sort u, stable

*figure out how large the value of iterations need to be
gen out_of_bag_error1 = .
gen validation_error = .
gen iter1 = .
local j = 0
forvalues i = 10(5)500 {
     local j = `j' + 1

     rforest defaultpaymentnextmonth limit_bal sex education marriage_enum* age pay* bill* in 1/1500, type(class) iter(`i') numvars(1)
     qui replace iter1 = `i' in `j'
     qui replace out_of_bag_error1 = `e(OOB_Error)' in `j'
     predict p in 1501/3000
     qui replace validation_error = `e(error_rate)' in `j'
     drop p
}

set scheme s1mono
label var out_of_bag_error1 "Out-of-bag error"
label var iter1 "Iterations"
label var validation_error "Validation error"
scatter out_of_bag_error1 iter1, mcolor(blue) msize(tiny) ||  scatter validation_error iter1, mcolor(red) msize(tiny)

graph save fig3.gph,replace 
图 3:OOB 误差、验证误差与迭代图
图 3:OOB 误差、验证误差与迭代图

从图 3 中可以看出,OOB 误差和验证误差均稳定在 19% 左右。因此,将迭代次数固定为 500 是一个不错的选择。

Note:由于 3 万样本运行非常耗时,这里只选择前三千样本量进行说明。因此,图形和原作者存在差异。

接下来,可以调整参数 numvars()

gen oob_error = .
gen nvars = .
gen val_error = .
local j = 0
forvalues i = 1(1)26{
     local j = `j' + 1
     rforest defaultpaymentnextmonth limit_bal sex ///
     education marriage_enum* age pay* bill* in 1/1500, type(class) ///
     iter(500) numvars(`i')
     qui replace nvars = `i' in `j'
     qui replace oob_error = `e(OOB_Error)' in `j'
     predict p in 1501/3000
     qui replace val_error = `e(error_rate)' in `j'
     drop p
}
label var oob_error "Out-of-bag error"
label var val_error "Validation error"
label var nvars "Number of variables randomly selected at each split"
scatter oob_error nvars, mcolor(blue) msize(tiny) ||   ///
scatter val_error nvars, mcolor(red) msize(tiny)
graph save fig4.gph, replace
图 4:OOB 误差、验证误差与变量数量图
图 4:OOB 误差、验证误差与变量数量图

从图 4 中可以看出,最小误差出现在多少个变量中。以下代码自动查找最小误差和 numvars(),需要 Stata16 版本。

cap frame drop mydata

*only run when tuning is run
frame put val_error nvars, into(mydata)

frame mydata { 
     sort val_error, stable
     local min_val_err = val_error[1]
     local min_nvars = nvars[1]
}

frame drop mydata
local min_val_err : display %9.4f `min_val_err'
local min_nvars = `min_nvars' 
di "Minimum Error: `min_val_err'; Corresponding number of variables `min_nvars'"

*list val_error if nvars==11
*texdoc sum val_error, nolog nooutput  // assuming that 11 is at the minimum
Minimum Error:  0.2047; Corresponding number of variables 11

可以看出,模型设定为 numvars(11),此时得到的最小验证误差为 0.2047。因此,将使用 numvars(11) 作为最终模型。

4.2 最终模型和结果解释

如前部分所示,将超参数的值设置为 iterations(500)numvars(11),经过 500 次迭代后达到了收敛。当然也可以自由地将迭代次数设置得更高。出于谨慎考虑,我们设置了迭代次数 iterations(1000)。以下代码块给出了最终模型的预测误差:

rforest defaultpaymentnextmonth limit_bal sex education marriage_enum* age pay* bill* in 1/1500, type(class) iter(1000) numvars(11)
di e(OOB_Error)
predict prf in 1501/3000
di e(error_rate)
. di e(OOB_Error)
.19133333

. di e(error_rate)
.20933333

最终的 OOB 误差为 18.25%,大于对 15000 个测试观察结果计算出的实际预测误差为 18.24%。从图 3 和图 4 可以看出,针对两个超参数 (即迭代次数和变量数) 进行绘制时,OOB 误差和验证误差具有相同的模式。

Note:推文代码得到最终的 OBB 误差为 20.93%,也大于 20.47%。

怎样确定哪些因素在预测过程中最重要?由于随机森林是一个黑匣子,它们不能提供关于预测是如何实现的,每个预测变量的重要性变量得分只能提供有限的洞察力。以下代码段可以绘制变量的重要性:

matrix importance = e(importance)
svmat importance
gen importid=""

local mynames : rownames importance
local k : word count `mynames'
if `k'>_N {
     set obs `k'
}
forvalues i = 1(1)`k' {
     local aword : word `i' of `mynames'
     local alabel : variable label `aword'
     if ("`alabel'"!="") qui replace importid= "`alabel'" in `i'
     else qui replace importid= "`aword'" in `i'
}

graph hbar (mean) importance, over(importid, sort(1) label(labsize(2))) ///
     ytitle(Importance)
graph save fig5.gph, replace
图 5:预测变量的重要性得分
图 5:预测变量的重要性得分

从图 5 中可以看出,五个最重要的预测指标是基本的人口统计和背景信息,例如性别、教育程度、婚姻状况、以及每月支出限额 (limit_bal)。还可以看到,与其它预测变量相比,编码每月账单金额 (bill amt) 的变量都不是特别重要。但是,令人惊讶的是,在随机森林模型中,每月支出限额 (limit_bal) 是第三重要的预测因子。为此,可以叠加两个每月支出限额的直方图,以获得有关此变量如何影响响应变量的更多解释:


twoway (hist limit_bal if defaultpaymentnextmonth == 0) ///
     (hist limit_bal if defaultpaymentnextmonth == 1,   ///
     fcolor(none) lcolor(black)), ///
     legend(order(1 "no default" 2 "default" ))

graph save fig6.gph,replace
图 6:每月支出限额的直方图
图 6:每月支出限额的直方图

从图 6 的直方图中可以看出,违约的持卡人的月度消费限额通常比不违约的持卡人低。变量重要性衡量了 x 变量对模型的贡献,但取决于 x 变量集。如果排除第一个 x 变量,则与第一个 x 变量相关的另一个 x 变量的重要性将提高。

4.3 与逻辑回归比较

logistic defaultpaymentnextmonth limit_bal sex education  marriage_enum* age pay* bill* in 1/1500
note: marriage_enum1 != 0 predicts failure perfectly
      marriage_enum1 dropped and 5 obs not used

note: marriage_enum4 omitted because of collinearity

Logistic regression                             Number of obs     =      1,495
                                                LR chi2(24)       =     167.93
                                                Prob > chi2       =     0.0000
Log likelihood = -694.93427                     Pseudo R2         =     0.1078

-----------------------------------------------------------------------------------------
defaultpaymentnextmonth | Odds Ratio   Std. Err.      z    P>|z|     [95% Conf. Interval]
------------------------+----------------------------------------------------------------
              limit_bal |   1.000001   6.71e-07     0.92   0.358     .9999993    1.000002
                    sex |   1.074101    .150194     0.51   0.609     .8166198    1.412768
              education |   1.041311   .0990331     0.43   0.670     .8642262    1.254681
         marriage_enum1 |          1  (omitted)
         marriage_enum2 |   1.427221   .7367029     0.69   0.491      .518941    3.925224
         marriage_enum3 |   1.238135   .6475852     0.41   0.683      .444183    3.451233
         marriage_enum4 |          1  (omitted)
                    age |   1.011494   .0085207     1.36   0.175     .9949314    1.028333
                  pay_0 |   1.745042   .1335627     7.27   0.000     1.501953    2.027475
                  pay_2 |   .8536819   .0790831    -1.71   0.088      .711939    1.023645
                  pay_3 |   1.274813   .1295042     2.39   0.017     1.044661    1.555669
                  pay_4 |   .9464986   .1153977    -0.45   0.652     .7453169    1.201985
                  pay_5 |   1.361633   .1695699     2.48   0.013     1.066735    1.738056
                  pay_6 |   .7999327   .0826644    -2.16   0.031     .6532674    .9795259
               pay_amt1 |   .9999748   .0000125    -2.01   0.044     .9999502    .9999994
               pay_amt2 |   .9999818   .0000112    -1.62   0.106     .9999598    1.000004
               pay_amt3 |   .9999945   8.73e-06    -0.63   0.528     .9999774    1.000012
               pay_amt4 |   .9999858   .0000101    -1.41   0.158      .999966    1.000006
               pay_amt5 |   1.000009   5.97e-06     1.47   0.141     .9999971     1.00002
               pay_amt6 |    .999999   4.89e-06    -0.20   0.842     .9999894    1.000009
              bill_amt1 |   .9999931   5.07e-06    -1.37   0.172     .9999831    1.000003
              bill_amt2 |   1.000004   6.49e-06     0.59   0.554     .9999911    1.000017
              bill_amt3 |   1.000006   6.04e-06     1.05   0.294     .9999945    1.000018
              bill_amt4 |    .999998   6.58e-06    -0.30   0.764     .9999851    1.000011
              bill_amt5 |   1.000007   7.89e-06     0.85   0.397     .9999912    1.000022
              bill_amt6 |   .9999919   5.34e-06    -1.51   0.130     .9999815    1.000002
                  _cons |   .1182257   .0820279    -3.08   0.002     .0303478    .4605711
-----------------------------------------------------------------------------------------
predict plogit in 1501/3000
replace plogit = 0 if plogit <= 0.5 & plogit != .
replace plogit = 1 if plogit > 0.5 & plogit != .
gen error = plogit != defaultpaymentnextmonth
sum error in 1501/3000
    Variable |        Obs        Mean    Std. Dev.       Min        Max
-------------+---------------------------------------------------------
       error |      1,500    .2133333    .4097977          0          1

可以看出,使用逻辑回归得到的预测误差为 18.86%,而从随机森林获得最好的误差率为18.25%。两个模型的误差率的差异很小,但对于防止信用卡违约仍可能有意义。

Note:本文第二个案例为「在线新闻受欢迎程度」,并将其与「线性回归」比较,由于与第一个案例较为相似,本文不再赘述,详见 default_news_do

5. 结语

分类和回归示例说明,随机森林模型通常比逻辑回归和线性回归等相应的参数模型具有更高的预测准确性。虽然示例主要集中在调整选项 iterations()numvars() 上,但是在参数期间可以考虑其他超参数,例如最大树深度和叶节点的最小值调整等。

机器学习通常需要将大样本划分为训练集和测试集,但采用低性能计算机时,Stata 运行非常耗费时间。本文案例中,改变了样本量,导致结论与原文作者存在一定差异,但这并不能说明随机森林算法不优越。主要在于,缩小样本后,无法得到更好的训练效果。

6. 参考资料

  • Breiman L. Random forests[J]. Machine learning, 2001, 45(1): 5-32. -PDF-
  • Schonlau M, Zou R Y. The random forest algorithm for statistical learning[J]. The Stata Journal, 2020, 20(1): 3-29. -PDF-

7. 相关推文

Note:产生如下推文列表的命令为:
lianxh 机器 logit, m
安装最新版 lianxh 命令:
ssc install lianxh, replace

相关课程

连享会-直播课 上线了!
http://lianxh.duanshu.com

免费公开课:


课程一览

支持回看

专题 嘉宾 直播/回看视频
最新专题 因果推断, 空间计量,寒暑假班等
数据清洗系列 游万海 直播, 88 元,已上线
研究设计 连玉君 我的特斯拉-实证研究设计-幻灯片-
面板模型 连玉君 动态面板模型-幻灯片-
面板模型 连玉君 直击面板数据模型 [免费公开课,2小时]

Note: 部分课程的资料,PPT 等可以前往 连享会-直播课 主页查看,下载。


关于我们

  • Stata连享会 由中山大学连玉君老师团队创办,定期分享实证分析经验。直播间 有很多视频课程,可以随时观看。
  • 连享会-主页知乎专栏,300+ 推文,实证分析不再抓狂。
  • 公众号关键词搜索/回复 功能已经上线。大家可以在公众号左下角点击键盘图标,输入简要关键词,以便快速呈现历史推文,获取工具软件和数据下载。常见关键词:课程, 直播, 视频, 客服, 模型设定, 研究设计, stata, plus, 绘图, 编程, 面板, 论文重现, 可视化, RDD, DID, PSM, 合成控制法

连享会主页  lianxh.cn
连享会主页 lianxh.cn

连享会小程序:扫一扫,看推文,看视频……

扫码加入连享会微信群,提问交流更方便

✏ 连享会学习群-常见问题解答汇总:
https://gitee.com/arlionn/WD

New! lianxh 命令发布了:
随时搜索连享会推文、Stata 资源,安装命令如下:
. ssc install lianxh
使用详情参见帮助文件 (有惊喜):
. help lianxh