ICML 2024 | 图结构泛化新视角:从因果干预到不变性学习
1. 图数据泛化为什么你的模型换个场景就“失灵”了大家好我是老吴一个在图机器学习领域摸爬滚打了十来年的“老炮儿”。这些年从社交网络推荐到药物分子发现我亲眼见证了图神经网络GNN从实验室走向工业界的全过程。但有一个问题几乎在每个项目里都会让我和团队头疼不已辛辛苦苦训练好的模型换个数据集或者场景性能就断崖式下跌。比如我们曾用一个在学术引用网络比如Cora、PubMed上表现优异的GNN模型去预测一个企业内部知识图谱中的文档分类结果准确率掉了将近20%。又比如一个在某个特定蛋白质结构数据集上预测结合位点很准的模型换到另一个家族的蛋白质上效果就大打折扣。这背后其实就是我们今天要聊的核心难题——图数据的分布外泛化Out-of-Distribution Generalization, OOD。简单来说OOD问题就是模型在训练时见过的数据分布和它实际应用时遇到的数据分布不一样了。对于图像和文本这种“不一样”可能体现在颜色、风格、词汇上。但对于图数据情况要复杂得多。图不仅包含节点特征更关键的是它那千变万化的拓扑结构——谁和谁相连连接的模式是什么整个图的形状如何。这种结构上的分布偏移是图OOD问题最棘手的地方。想象一下你教一个孩子识别“猫”只给他看各种品种的宠物猫照片训练分布。然后突然让他去野外识别一只猞猁测试分布他很可能认不出来。因为猞猁虽然也是猫科但耳朵有簇毛、体型更大、生活环境完全不同。在图数据里这种“品种”和“环境”的变化就体现在节点连接模式、社区结构、甚至整个图的生成机制上。ICML 2024等顶会上最新的研究正试图从两个深刻的视角来解决这个问题不变性学习和因果干预。它们一个像在寻找“猫”这个物种永恒不变的本质特征比如胡须、肉垫另一个则像在探究“为什么它是猫”背后的因果机制。下面我就结合自己的实战经验带大家深入浅出地看看这两条路怎么走以及它们如何殊途同归。2. 第一性原理基于不变性学习的“以不变应万变”当我们面对一个多变的世界时最可靠的策略是什么答案是抓住那些不变的东西。这就是不变性学习Invariant Learning的核心哲学。它的想法很直观在纷繁复杂的数据背后总存在一些输入X和输出Y之间稳定、可靠的关系这些关系不随环境E的变化而改变。模型只要学会了这些“不变特征”就能在任何新环境中稳健预测。2.1 图上的“不变特征”到底指什么把不变性原理套用到图上需要一些巧思。因为图数据的基本单元是节点和边我们不能简单地把每个节点特征单独拎出来看。2022年ICLR的一篇经典工作《Handling Distribution Shifts on Graphs: An Invariance Perspective》给出了一个非常漂亮的定义。它提出以图中每个节点为中心观察它的局部子图。在这个子图里所有邻居节点的特征信息都会汇聚到中心节点共同决定它的标签。这些信息中有一部分是“不变特征”它们对标签的贡献在任何环境下都稳定另一部分是“冗余特征”或“虚假相关性”它们只在特定环境下与标签强相关。我举个接地气的例子。假设我们要预测一个社交平台上用户的职业标签Y。用户的特征X可能包括教育背景X1、发文活跃度X2。环境E可以是不同的平台比如学术型的ResearchGate和娱乐型的微博。不变特征X1用户的教育背景如博士学历与“研究员”这个职业的关系无论在哪个平台都是强相关的。这是一种本质联系。冗余特征X2在ResearchGate上发文极度活跃的用户很可能就是研究员但在微博上活跃度高的可能是营销号或网红。因此“高活跃度”与“研究员”的关联是平台依赖的是虚假相关性。一个鲁棒的模型应该学会依赖“教育背景”这个不变特征来做判断而不是被“活跃度”带偏。这篇论文的精妙之处在于它通过理论推导将这种直觉形式化证明了在图结构下这种以子图为单位的不变性假设是合理且可学习的。2.2 实战算法EERM——自己给自己“出难题”道理懂了但怎么让模型学会抓住不变特征呢最大的挑战是我们通常没有明确的环境标签E。我们有一大堆混合了各种环境的数据但不知道每个样本具体属于哪个“平台”或“场景”。那篇ICLR论文提出的探索-外推风险最小化EERM算法提供了一个非常聪明的解决方案。它的核心思想是既然没有现成的不同环境我们就自己动手创造出一些“虚拟环境”来训练模型。具体怎么创造算法会引入K个“环境生成器”。这些生成器的任务不是生成逼真的新数据而是对原始训练数据进行各种“扰动”或“增强”从而模拟出数据可能在不同环境下呈现的多样化形态。在优化过程中算法有两个目标内层目标探索让这K个生成器产生的数据尽可能“不同”最大化它们之间的分布差异。这就好比给学生出K套考点侧重点完全不同的模拟卷。外层目标外推在这K套不同的“模拟卷”上训练主预测模型但要求模型在所有卷子上的表现都稳定。具体做法是最小化模型在K个环境上的平均损失同时还要最小化这些损失之间的方差。这意味着模型不能只在某套卷子上考高分必须在所有“刁难”它的卷子上都表现稳健。通过这种“自己创造困难并克服它”的方式EERM逼迫模型去挖掘那些在所有虚拟环境下都通用的预测规律也就是我们想要的不变特征。我在一些涉及跨领域图分类的项目中尝试过这个思路的变体比如让模型同时学习从分子图和社交网络图中提取通用结构模式效果确实比直接混合训练要稳定不少。3. 直击本质基于因果干预的“斩断混淆之手”不变性学习很棒但它有一个很强的预设数据中必须存在我们假设的那种“不变关系”。如果这个假设不成立呢或者我们想追求更本质、更强大的泛化能力该怎么办这时我们需要换一个更基础的视角——因果性。因果干预Causal Intervention方法不假设不变特征的存在它的目标更宏大直接让模型学会从输入X到输出Y的因果机制。为什么因果机制就能泛化因为因果关系是事物间最本质、最稳定的联系它不会因为环境变化而改变。3.1 图学习中的“因果陷阱”要理解因果干预先得看清传统图学习掉进了什么“坑”。我们可以用一张简单的因果图来表示训练好的GNN模型环境(E) - 输入图(X) 环境(E) - 预测标签(Y_hat) 输入图(X) - 预测标签(Y_hat)这里环境E是一个混淆因子。它同时影响了我们观察到的图数据X比如微博这个环境产生了特定的用户连接模式和我们需要预测的标签Y比如该环境下“网红”这个职业更普遍。当我们用标准方法训练模型拟合P(Y|X)时模型实际上学到的是关联关系它混杂了X-Y的真实因果以及E通过X和Y产生的虚假路径E-X-Y和E-Y。结果就是模型把环境带来的虚假相关性也当成了预测依据。一旦环境改变从微博换到LinkedIn这些相关性失效模型就崩了。这就像那个经典例子通过“冰淇淋销量”和“溺水人数”的正相关来预测溺水是荒谬的因为它们背后共同的因果是“夏天”。模型需要学会的是“游泳技能”与“溺水”的因果关系而不是“冰淇淋”这个混淆因子。3.2 因果干预怎么做一个“做手术”的思想实验因果科学给我们提供了一把手术刀do-算子。我们想知道的不是“看到图X时Y是什么”而是“如果我们干预do图X让它强制变成某个样子那么Y会是什么”。这相当于在因果图上做手术切断了环境E对输入X的影响线。数学上我们的目标从最大化似然P(Y|X)变成了最大化干预后的似然P(Y|do(X))。通过因果推断中的后门调整公式这个难以直接计算的目标可以转化为对所有可能环境e求期望Σ_e P(Y|X, Ee) P(Ee)。意思是模型应该学会在所有可能环境下的平均表现。但问题又来了环境E通常是未知的今年WWW 2024的一篇工作《Graph Out-of-Distribution Generalization via Causal Intervention》提出了一个巧妙的工程解决方案——变分环境调整。3.3 实战算法变分推断“脑补”环境既然真实环境未知我们就用一个神经网络环境推断器来“脑补”它。这个推断器会根据输入图X猜测它可能来自哪个潜在环境。整个训练过程变成一个三者的博弈环境推断器努力从数据中识别出有区分度的环境模式。GNN预测器努力基于X和推断器提供的环境信息做出准确的预测。一个先验约束防止环境推断器“偷懒”比如把所有样本都归为同一环境通常要求推断出的环境分布尽量均衡、有区分度。通过联合优化这三个部分模型在某种程度上实现了“后门调整”预测器被迫去学习那些在不同“脑补”环境下都稳健的特征因为这些特征才更可能反映X-Y的因果。这个方法CaNet的美妙之处在于它是模型无关的可以套用在GCN、GAT等各种GNN骨架上。在实际尝试中我发现它在处理那些由时间演变如引文网络按年份划分或空间隔离如不同地区的社交子图引起的分布偏移时尤其有效。4. 融合与进阶当图结构本身也“不确定”时前面讨论的方法无论是EERM还是CaNet都默认我们手中的图结构邻接矩阵是清晰、确定、完全观测的。但现实往往更骨感。很多数据比如一组蛋白质序列、一堆用户行为日志它们之间的“关系”并非显而易见需要我们自己去推断或构建。这种隐式图结构带来了OOD问题的又一重挑战分布偏移不仅来自特征还可能来自我们构建图的方式比如用KNN选不同的K值甚至来自底层数据生成机制的根本变化。今年ICML 2024的这篇《Learning Divergence Fields for Shift-Robust Graph Representations》就从更基础的视角切入将消息传递GNN的核心与物理中的扩散过程联系起来为我们提供了一个统一看待显式和隐式图结构的框架。4.1 把GNN层看作“热扩散”想象一下把图中的节点看作金属板上的点节点的特征表示是温度。GNN每一层的消息传递就像热量从高温点流向低温点的过程。在物理学中这个过程由扩散方程精确描述。论文指出GNN和Transformer的层更新公式其实就是这个扩散方程在离散时间和空间上的近似。其中控制热量如何流动的关键参数是扩散系数。在传统GNN中这个系数是固定的比如归一化的邻接矩阵在Transformer中它是通过注意力机制动态计算的。但无论固定还是动态在标准训练下这个系数都是和训练数据分布紧密耦合的——它学会了训练集里特定的“热量流动模式”。4.2 因果干预扩散过程让流动模式“去偏”问题来了。如果测试数据来自不同的“材料”分布不同热量流动的模式变了用旧模式训练的扩散模型GNN自然就失效了。你看这里的“扩散系数”扮演了之前“环境E”的角色它混淆了输入和输出。解决方案依然是因果干预。论文提出在扩散过程的每一步对应GNN的每一层都对扩散系数进行干预切断它和输入数据的依赖。同样由于扩散系数是隐变量他们采用了更复杂的变分推断技巧来近似这个干预目标。这带来了极大的灵活性。他们基于此框架实现了三个模型GLIND-GCN扩散系数是常数矩阵退化为处理显式结构的经典GCN但训练方式更鲁棒。GLIND-GAT扩散系数是随时间层数变化的用注意力机制学习适用于显式图但关系权重可变的场景。GLIND-Trans扩散系数是全局的、时变的采用高效的线性注意力专门为隐式图结构比如一组需要推断关系的样本设计。这个工作的价值在于它把OOD泛化的问题提升到了“动力学系统”的层面。在我们处理一些结构模糊的生物序列数据时这种思想特别有用。我们不再纠结于一个“绝对正确”的图应该长什么样而是让模型学会在多种可能的结构假设下都能捕捉稳定的预测规律。5. 对比、选择与实战心法聊了这么多理论大家可能有点晕。简单总结一下这两大流派的核心区别特性基于不变性的方法 (如EERM)基于因果干预的方法 (如CaNet, GLIND)核心思想寻找跨环境稳定的统计关联不变特征学习输入到输出的因果机制关键假设数据中存在可分离的、环境不变的特征存在单一的、可识别的混淆因子环境需要环境标签不需要通过生成或推断虚拟环境不需要通过变分推断潜在环境优势直觉清晰在满足不变性假设的场景下理论保障强目标更本质对数据生成机制假设相对更弱挑战不变性假设可能过强虚拟环境的生成质量影响大变分推断可能陷入平凡解对复杂混淆多环境、未观测混淆处理难那么在实际项目中该怎么选呢根据我的经验可以遵循以下思路先分析你的数据偏移来源如果偏移主要来自节点/边特征的虚假相关性比如不同平台上的用户行为模式差异而不变特征确实存在如用户的基础属性那么不变性学习方法可能更直接有效。你可以尝试EERM或其变种重点设计能够分离特征的环境生成器。如果偏移来源更综合、更根本可能涉及图结构的生成机制变化比如不同社交网络的形成规则不同或者你怀疑有强烈的混淆因子那么因果干预方法可能潜力更大。CaNet是一个不错的起点。考虑结构的明确性如果你的图结构是清晰、固定的如已知的分子键、固定的知识图谱上述两种方法都可以尝试。如果你的图结构是隐式、需要构建或不确定的如用KNN从特征构建的图或关系不断变化的时序图那么像GLIND这类从扩散过程出发、能建模结构不确定性的框架可能更合适。一个实用的融合策略 在实际中我常常不会死磕某一种方法。一个有效的策略是将因果干预的思想作为正则化项融入不变性学习的框架。例如在训练时除了EERM的损失可以额外添加一个约束鼓励模型学到的表征与变分推断出的环境变量尽可能独立类似去相关。这种混合损失往往能取得比单一方法更稳定的效果。永远不要忘记数据本身 再好的算法也离不开对数据的深刻理解。在应用这些高级方法前花时间做探索性数据分析至关重要。可视化你的图看看训练集和测试集在节点度分布、聚类系数、社区结构上有什么不同。这些直观的洞察能帮你判断该用哪种方法甚至启发你设计更贴合问题的环境变量或数据增强策略。最后我想说图数据的OOD泛化是一个远未解决的问题不变性学习和因果干预为我们打开了两扇充满希望的门。但它们都不是银弹。真实世界的数据复杂多变混淆因子可能不止一个不变特征可能难以分离。我在工程实践中最大的体会是理论给我们方向但最终解决问题要靠对业务的深入理解、大量的实验迭代以及谨慎的评估。不要指望用一个现成的包就能解决所有泛化问题多看看数据多设计几个对照实验理解模型在什么情况下会失败往往比直接调参更有价值。这条路还很长但每解决一个实际的泛化问题带来的价值也是巨大的。希望这些分享能帮你少踩一些坑。