GNN 2021(三) Boost then Convolve: Gradient Boosting Meets Graph Neural Networks,ICLR

文章目录
-
GBDT
-
- 提升树
-
Method
-
实验
如前所述,在这篇论文中,我们将梯度提升决策树(GBDT)与图神经网络融合,并开发出一个完整的系统架构来处理具有固定属性维度的结构化数据集。
GBDT
论文里阐述了梯度提升方法的核心思想:通过逐步累加弱学习器(通常为决策树)来构建一个强学习器。具体而言,在梯度增强算法的每一次迭代t中,在给定训练集上不断累加弱学习器f_t(x)到当前模型f(x)中,并采用加性更新的方式逐步优化预测效果。

在上一个迭代阶段,我们已经生成了一个模型f^{t-1}。第t次迭代中的基函数h^t是弱学习器,在此过程中其学习速率由参数\epsilon控制。每个弱学习器的目标是通过最小化损失函数L的方向梯度来逐步改进预测性能。

实际上,在这篇论文中对提升树的相关内容表述不够清晰。基于此认识,我决定重新梳理这部分与提升树相关的知识点。
提升树

提升树的m步为:

对于平方误差,其前向推导为:
Loss=\sum_{M}L(y_i, f_m(x_i))=\sum_{M}(y_i-f_m(x_i))^2=\sum_{M}(y_i-f_{m-1}(x_i)-T_m(x,θ_m))^2
然后通过最小化上述损失函数求取最优的参数θ。那么,r_{mi}=y_i-f_{m-1}(x_i)就是好多帖子里所提到的残差。但是,如果损失函数不是平方误差,那么上述残差推导的式子就不成立了,为了优化模型就必须找一个学习最快最好的方向,那就是负梯度。为啥负梯度最优?上推导:
首先,我们还是按照刚刚的推导去假定损失函数:L(y_i, f_m(x_i)),f_m可以由f_{m-1}导出,因此损失函数可以写成:
L(y_i, f_{m-1}(x_i)+T_m(θ_m))
这里,把f_{m-1}(x_i)+T_m(θ_m)整体视为一个变量,求X=f_{m-1}(x_i)+T_m(θ_m)在f_{m-1}(x_i)处的泰勒展开式:
Taylor(L)=L(f_{m-1}(x_i))+L'(f_{m-1}(x_i))(X-f_{m-1}(x_i))
=L(y_i,f_{m-1}(x_i))+\frac{\Delta L(y_i, f_{m-1}(x_i))}{\Delta f_{m-1}(x_i)}T_m(θ_m)
公式中的前一项是上一步的损失函数,那么为了让损失减少,就必须保证后一项是负数,也就是\frac{\Delta L(y_i, f_{m-1}(x_i))}{\Delta f_{m-1}(x_i)}T_m(θ_m)<0,最为便捷快速的方法就是使得T_m(θ_m)=-\frac{\Delta L(y_i, f_{m-1}(x_i))}{\Delta f_{m-1}(x_i)}。当损失函数是均方损失时,负梯度刚好是残差,残差只是特例,也就是论文给出的公式。换一句人话:GBDT并不是用负梯度代替残差!!!GBDT建树时拟合的是负梯度!但GBDT使用的回归树经常使用平方误差作为划分准则,所以此时拟合的目标值可以看成是残差的形式。 (参照知乎回答)
然后在这里推导过程可以见哔哩哔哩,讲课的是一个声音好听的小姐姐。(▽)
Method

模型的核心模块致力于探索一种全新的联合训练机制。尽管GBDT与GNN在优化目标上存在本质差异:GNN通过梯度下降优化其参数θ(θ位于可微空间中),而GBDT则以迭代方式构建决策树(决策树基于硬分割特性导致不可微)。
如图所示,在初始迭代阶段构建了包含k棵决策树的GBDT集成模型f¹(x),并通过损失函数L(f¹(x), y)进行参数优化(即算法中所述的目标函数为Y)。随后利用经过优化的GBDT对节点特征进行变换得到X';接着通过GNN模块gθ完成特征传播并最小化其损失函数L(gθ(G,X'), Y);最后通过最小化两模型输出之间的差异来更新节点表示:X'_new = X' - ∇_X' L(gθ(G,X'), Y)。(当l=1时)完全等于损失函数关于X'的负梯度方向上的更新步长:ΔX' = -∇_X'L(gθ(G,X'), Y))

随后训练第二个弱学习器f^2(即第二个决策树),并将其目标替换为X_{new}'-X'。从直观上看,在完成f^2的训练后,将两个弱学习器的输出相加得到新的输入特征:f(X) = f^1(x) + f^2(x)。换句话说,在迭代过程中,我们不仅构建了一个集成模型f: X \rightarrow Y,还发展了一种将图神经网络用于梯度优化的方法。
实验
数据集:

这些数据集之前未曾接触过;它们都属于回归分析范畴。其中House类任务主要聚焦于用于预测房价;County类任务则致力于预测区域失业率;VK类任务通常用于估计网络节点的时间属性。此外,请参见论文附录中的详细讨论。

也做了节点分类,但是并不是常见的分类数据集:


gap是相对差值,越小越好。

训练时间:

训练过程中损失值的变化:

可视化结果:

