温馨提示: 定期 清理浏览器缓存,可以获得最佳浏览体验。
作者: 田原 (北京交通大学)
E-mail: godfreytian@163.com
支持向量机:Stata 和 Python 实现
目录
支持向量机(support vector machines, SVM)的基本模型是定义在特征空间上的间隔最大的线性分类器,间隔最大使它有别于感知机;SVM 还包括核技巧,这使它成为实质上的非线性分类器。SVM 的学习策略就是间隔最大化,可形式化为一个求解凸二次规划的问题,也等价于正则化的合页损失函数的最小化问题。SVM 的学习算法就是求解凸二次规划的最优化算法。
了解 SVM 算法之前,首先需要了解一下线性分类器这个概念。比如给定一系列的数据样本,每个样本都有对应的一个标签。为了使得描述更加直观,我们以二维平面为例,将特征向量映射为空间中的一些点,就是如下图的实心点和空心点,它们属于不同的两类。
那么 SVM 的目的就是想要画出一条线,以“最好地”区分这两类点,以至如果以后有了新的点,这条线也能做出很好的分类,也就是说要使条直线能够达到最好的泛化能力。那么能够画出多少条线对样本点进行区分?线是有无数条可以画的,区别就在于效果好不好。比如下图中绿线就不好,蓝线一般,红线看起来会更好。我们所希望找到的这条效果最好的线叫作划分超平面,它是一个能使两类之间的空间大小最大的一个超平面。
为什么要叫作“超平面”呢?因为样本的特征很可能是高维的,此时样本空间的划分就需要“超平面”。这个超平面在二维平面上看到的就是一条直线,在三维空间中就是一个平面,因此,我们把这个划分数据的决策边界统称为超平面。离这个超平面最近的点就叫做支持向量。支持向量机就是要使超平面和支持向量之间的间隔尽可能的大,这样超平面才可以将两类样本准确的分开,而保证间隔尽可能的大就是保证分类器误差尽可能的小。
那么画线的标准是什么?如何才能画出效果好的线?SVM 将会寻找可以区分两个类别并且能使边际(margin)最大的超平面(hyper plane),即划分超平面。边际就是分类的超平面和对应类别最近的样本点之间的距离。
如上图所示,所有坐落在边际超平面上的点被称为支持向量(support vectors),它们是用来定义边际的,是距离划分超平面最近的点。 支持向量支撑了边际区域,并且用于建立划分超平面。值得注意是,支持向量每一侧可能不止一个,有可能一侧有多个点都落在边际平面上。
如下图所示,SVM的目标就是使这个边际(
SVM 所训练出的模型,其算法复杂度是由支持向量的个数决定的,而不是由数据的维度决定的,当然支持向量的个数多少也和训练集的大小有关。因此 SVM 可以一定程度上避免过拟合(overfitting)的现象。 SVM 训练出来的模型完全依赖于支持向量,即使训练集里面所有非支持向量的点都被去除,重复训练过程,结果依然会得到完全相同的模型。 若一个 SVM 训练得出的支持向量个数比较少,那么 SVM 训练出的模型具有较好的泛化能力。
SVM 所训练出的模型,其算法复杂度是由支持向量的个数决定的,而不是由数据的维度决定的,当然支持向量的个数多少也和训练集的大小有关。因此 SVM 可以一定程度上避免过拟合(overfitting)的现象。 SVM 训练出来的模型完全依赖于支持向量,即使训练集里面所有非支持向量的点都被去除,重复训练过程,结果依然会得到完全相同的模型。 若一个 SVM 训练得出的支持向量个数比较少,那么 SVM 训练出的模型具有较好的泛化能力。
SVM的目标就是找出边际最大的超平面,那么如何找出这个最大边际的超平面呢(MMH)?利用 Karush-Kuhn-Tucker(KKT)条件和拉格朗日公式,可以推出 MMH 可以被表示为以下决定边界(decision boundary)
该方程就表示边际最大化的划分超平面。
每当有新的测试样本
在线性 SVM 中转化为最优化问题时求解的公式计算是以内积(dot product)形式出现的,假设原始的数据是非线性的,我们通过一个映射
这核所对映的映射是可以表示出来的,该空间的维度是
这个核就会将原始空间映射为无穷空间,不过,如果
这个核存在的主要目的是使得“映射后空间中的问题”和“映射前空间中的问题”两者在形式上统一起来。
在选取核函数解决实际问题时,通常采用的方法有两种: 一是利用专家的先验知识预先选定核函数 二是采用 Cross-Validation 方法,即在进行核函数选取时,分别试用不同的核函数,归纳选取误差最小的核函数。
以预测中小企业信用风险为例,选取 2019 年在中小板上市的企业为样本,预测其是否具有违约风险,我们将上市企业中 ST 或 *ST 的企业认为是有违约风险的企业,选取 6 项财务指标进行预测,财务指标分别为:流动比率、速动比率、利润率、ROA、总资产增长率、现金比率。
# 首先要调用需要的模块
import numpy as np #调用numpy模块
import pandas as pd #调用Pandas模块
# 读取数据
X1=pd.ExcelFile('dataset.xlsx') #读取数据
X1.sheet_names
# 设置特征变量
X_Features = pd.read_excel(X1, sheet_names="Sheet1", usecols=[1,2,3,4,5,6])
X_Features.head() #可以查看特征变量,该步骤可以省略
输出的结果为 6 项用于预测的财务指标,仅展示出前 5 行数据。
current ratio | profit margin | quick ratio | ROA | assets growth rate | cash ratio | |
---|---|---|---|---|---|---|
0 | 1.2575 | -444.3600 | 0.9753 | -142.4324 | -75.2146 | 56.3938 |
1 | 0.4772 | -41.0634 | 0.3112 | -5.0946 | 26.3844 | 4.7799 |
2 | 2.1478 | -60.3331 | 1.3147 | -28.1622 | -7.4949 | 75.0185 |
3 | 0.2740 | -24.0082 | 0.1667 | -18.8108 | -43.2743 | 3.8674 |
4 | 0.7362 | -101.3008 | 0.5202 | -36.2078 | -44.9836 | 26.0609 |
X_Features.info() #可以查看特征变量的数据类型,该步骤可以省略
# 设置标签
Y_Response = pd.read_excel(X1, sheet_names="Sheet1", usecols=[7])
Y_Response = Y_Response["risk"].ravel() #设置标签
# 设置训练集和测试集
from sklearn import model_selection
X_Train, X_Test, Y_Train, Y_Test = model_selection.train_test_split(X_Features, Y_Response, test_size=0.3, shuffle=True) #划分出训练集和测试集,test_size表示测试集的比例,这里是30%,该数字可以更改;shuffle=True是将序列的所有元素随机排序
X_Train.head() #查看训练集,该步骤可省略
X_Test.head() #查看测试集,该步骤可省略
# 调用 SVM 算法
from sklearn.svm import SVC #调用Scikit learn中的SVM算法
# 设置参数进行拟合
C = 1e5 #惩罚因子,该数字可以更改
clf = SVC(C=C, kernel='rbf', gamma=20, decision_function_shape='ovr') #kernel='rbf'时(default),为高斯核,gamma值越小,分类界面越连续;gamma值越大,分类界面越“散”,分类效果越好,但有可能会过拟合。decision_function_shape='ovr'时,为one v rest,即一个类别与其他类别进行划分。
clf.fit(X_Train, Y_Train.ravel()) #对训练集进行拟合
# 对测试集数据进行预测
y_pred = clf.predict(X_Test) #对测试集数据进行预测
#查看预测的准确度
from sklearn.metrics import classification_report
print (classification_report(Y_Test, y_pred)) #查看预测的准确率(Accuracy)、精确率(Precision)、召回率(Recall)和F值(F-Measure)等指标
precision recall f1-score support
0 0.98 1.00 0.99 270
1 0.00 0.00 0.00 6
accuracy 0.98 276
macro avg 0.49 0.50 0.49 276
weighted avg 0.96 0.98 0.97 276
从预测结果来看,预测的准确率(Accuracy)为 98%,虽然整体上预测的准确度较高,但是在测试集中有违约风险的 6 家企业却没有预测出来,主要是因为在整体的样本中,有违约风险的企业数量过少,这也说明在进行信用风险预测时,有信用风险的企业样本不能太少。
使用 Stata自带的 1978 Automobile Data 数据,预测汽车类型是“Domestic”还是“Foreign”。在这一例子中,使用2项指标进行预测,分别是 price 和 gear_ratio。 在 Stata 中的 SVM 实现需要安装外部命令 svmachines,具体实现过程如下。
# 在使用Stata 进行SVM实现之前,需要先安装外部命令:svmachines
ssc install svmachines, replace
# 第一步需要做的是区分训练集和测试集
sysuse auto, clear
set seed 9876
generate u = runiform()
sort u
local split = floor(_N/2)
local train = "1/`=`split'-1'"
local test = "`split'/`=_N'"
# 在训练集中利用SVM进行训练:
svmachines foreign price gear_ratio if !missing(rep78) in `train'
# 测试集里进行预测,并统计准确率:
predict P in `test'
generate err = foreign != P in `test'
tab err in `test'
err | Freq. Percent Cum.
------------+-----------------------------------
0 | 28 73.68 73.68
1 | 10 26.32 100.00
------------+-----------------------------------
Total | 38 100.00
以上输出结果中,0 表示预测正确,1 表示预测错误。根据以上输出结果,就可以进一步计算出准确率等指标。在这一例子中,预测的准确率(Accuracy)为:28 ÷ 38 = 73.68%
下面我们使用 Stata 的另一组数据进行 SVM 的 Stata 实现。使用 Stata自带的 nlsw88.dta 数据,该数据包含了 1988 年采集的 2246 个美国妇女的资料。在这一示例中,我们使用美国妇女的工资“wage”去预测“union”是否为工会成员,具体实现过程如下。
# 第一步和上一个例子相同,我们需要做的是区分训练集和测试集
sysuse nlsw88.dta, clear
set seed 9876
generate u = runiform()
sort u
local split = floor(_N/2)
local train = "1/`=`split'-1'"
local test = "`split'/`=_N'"
# 在训练集中利用SVM进行训练:
svmachines union wage if !missing(union) in `train' # if !missing(union)的目的是为了在进行机器学习时忽略缺失的数据项,也可以事先使用drop的功能将有缺失的数据drop掉
# 测试集里进行预测,并统计准确率:
predict P in `test'
generate err = union != P in `test'
tab err in `test'
err | Freq. Percent Cum.
------------+-----------------------------------
0 | 709 63.08 63.08
1 | 415 36.92 100.00
------------+-----------------------------------
Total | 1,124 100.00
以上输出结果中,0 表示预测正确,1 表示预测错误。根据以上输出结果,就可以进一步计算出准确率等指标。在这一例子中,预测的准确率(Accuracy)为:709 ÷ 1124 = 63.08%,相较于上一个例子,准确率有所下降。这主要是因为在这一例子中,我们中主要为了展示 SVM 的实现,因此只选取了 wage 这一个变量对 union 进行预测,大家在做预测时通常会选取更多的变量。
连享会-直播课 上线了!
http://lianxh.duanshu.com
免费公开课:
直击面板数据模型 - 连玉君,时长:1小时40分钟 Stata 33 讲 - 连玉君, 每讲 15 分钟. 部分直播课 课程资料下载 (PPT,dofiles等)
支持回看,所有课程可以随时购买观看。
专题 | 嘉宾 | 直播/回看视频 |
---|---|---|
⭐ 最新专题 ⭐ | DSGE, 因果推断, 空间计量等 | |
⭕ Stata数据清洗 | 游万海 | 直播, 2 小时,已上线 |
研究设计 | 连玉君 | 我的特斯拉-实证研究设计,-幻灯片- |
面板模型 | 连玉君 | 动态面板模型,-幻灯片- |
面板模型 | 连玉君 | 直击面板数据模型 [免费公开课,2小时] |
Note: 部分课程的资料,PPT 等可以前往 连享会-直播课 主页查看,下载。
关于我们
课程, 直播, 视频, 客服, 模型设定, 研究设计, stata, plus, 绘图, 编程, 面板, 论文重现, 可视化, RDD, DID, PSM, 合成控制法
等
连享会小程序:扫一扫,看推文,看视频……
扫码加入连享会微信群,提问交流更方便
✏ 连享会学习群-常见问题解答汇总:
✨ https://gitee.com/arlionn/WD