上一篇文章简单介绍了几种集成方法,以随机森林为代表的平均方法,以adaboost、提升树为代表的提升方法。本文继续提升树的讨论,采用梯度提升算法,分析更一般的损失函数下的分类、回归问题。首先介绍几种在机器学习算法中常见的损失函数,继而介绍GBDT算法。
1、 损失函数
1.1、 分类问题损失函数
在二分类中把
m=y⋅f(x)
称为margin。当
m<0 时则说明误分类。基于此,下面介绍几种常见的分类损失函数。
1.1.1、 0-1损失
表达式为:
L01(m)=I(sign(f)≠y)
或者
L01(m)={0,m⩾01,m<0
即只对误分类样本进行惩罚,类似有后列几种损失。
1.1.2、 指数损失
表达式:
Le=exp(−yf)
Adaboost算法就是用的此种损失函数。在上一篇文章中有介绍。
1.1.3、 log损失
表达式:
Ll=log(1+exp(−yf))
在逻辑回归中用到的损失函数。
1.1.4、均方损失
表达式:
L2=(y−f)2
1.1.5、hinge损失
表达式:
Lh=(1−yf)+
在支持向量机中使用。
1.1.6、偏差损失函数
考虑K分类问题,假设在第k类的概率:
pk(x)=exp(fk(x))∑Kl=1exp(fl(x))
则多项式偏差:
L(y,p(x))=−∑k=1KI(y=gk)logpk(x)=−∑k=1KI(y=gk)fk(x)+log(∑k=1Kexp(fl(x)))
在gbdt分类中默认使用此损失函数。
1.2、 回归问题损失函数
下面介绍三种回归问题中常见的的损失函数。
1.2.1、 均方损失
表达式:
L(y,f(x))=(y−f(x))2
1.2.2、绝对损失
表达式:
L(y,f(x))=|y−f(x)|
1.2.3、Huber 损失
表达式:
L(y,f(x))={(y−f(x))2, |y−f(x)|⩽δ2δ(|y−f(x)|−δ2), otherwise
Huber损失相对于前两种损失更具有鲁棒性。
2、 GBDT算法
上篇文章已经介绍,当分类问题是指数损失,回归问题是均方损失时,用提升树算法很简单。当损失函数更为一般时,可以使用梯度提升的方法快速迭代,即为GBDT算法。在R中的包名为gbm,在python中的sklearn.ensemble目录下对应GradientBoostingRegressor和GradientBoostingClassifier。
2.1、 分类问题
GBDT在分类中的损失函数可选偏差损失和指数损失,默认是偏差损失,当选择指数损失时则还原为adaboost算法。前文已经介绍这两种损失,主要介绍偏差损失下K分类算法。
根据偏差损失:
pk(x)=exp(fk(x))∑Kl=1exp(fl(x))
L(y,p(x))=−∑k=1KI(y=gk)logpk(x)=−∑k=1KI(y=gk)fk(x)+log(∑k=1Kexp(fl(x)))
则负梯度:
−gikm=−∂L(yi,f1m(xi),...,fKm(xi))∂fkm(xi)=I(yi=gk)−pk(xi)
梯度提升树K分类算法
1. 初始化 fk0(x)=0,k=1,2,...,K .
2. for (m=1:M):
(a)计算:
pk(x)=exp(fk(x))∑Kl=1exp(fl(x)), k=1,2,...,K
.
(b)for (k=1:K):
i. 计算
rikm=yik−pk(xi), i=1,2,...,N
ii. 对
rikm, i=1,2,...,N 生成一颗
回归树,划分区域
Rjkm, j=1,2,...,Jm
iii. 计算:
γjkm=K−1K⋅∑xi∈Rjkmrikm∑xi∈Rjkm|rikm|(1−|rikm|), j=1,2,...,Jm
iv. 更新:
fkm(x)=fk,m−1(x)+∑j=1JmγjkmI(x∈Rjkm)
.
3. 输出
fk^(x)=fkM(x), k=1,2,...,K .
说明:
2(a)得到第k类的概率,类似多项逻辑回归。
特别地当处理二分类问题时,只需生成一颗回归树。
2.2、 回归问题
当提升树处理损失函数是均方损失的回归问题时,每次迭代拟合当前的残差生成一颗回归树。更一般的情况,根据加法模型
Θ^m=argminΘm∑i=1NL(yi,fm−1(xi)+T(xi;Θm))
因此,若每次新增的树在损失函数负梯度方向上,则能够令损失函数最速下降。
梯度提升树回归算法:
1. 初始化 f0(x)=argminγ∑Ni=1L(Yi,γ) .
2. for (m=1:M):
(a) for(i=1:N),计算
rim=−[∂L(yi,f(xi))∂f(xi)]f=fm−1
(b) 根据
rim 拟合回归树,划分得到区域
Rjm, j=1,2,...,Jm
(c)for (
j=1,2,...,Jm ),计算
γjm=argminγ∑xi∈RjmL(yi,fm−1(xi)+γ)
(d)更新
fm(x)=fm−1(x)+∑Jmj=1γjmI(x∈Rjm) 。
3. 输出
f^(x)=fM(x)
说明:当损失函数为均方损失时,2(c)中 γjm 为 Rjm 中的均值。