数据稀缺依然是阻碍机器学习在分子性质预测与设计中发挥效能的核心难题,影响范围涵盖医药、溶剂、高分子材料以及能源载体等多个领域。尽管多任务学习(MTL)能够利用不同性质之间的相关性提升预测性能,但由于训练数据分布不均,常常受到负迁移的干扰,削弱了其效果。

为此,沙特阿卜杜拉国王科技大学的研究团队于2025年7月8日在communications chemistry发表最新成果,提出了一种适用于多任务图神经网络的训练策略ACS(adaptive checkpointing with specialization),该方法在保留MTL优势的同时,减轻了任务间的有害干扰。

研究团队在多个分子性质预测基准任务中对ACS进行了验证,结果显示其性能稳定优于或不逊于当前主流监督方法。为评估其实际应用潜力,研究者将ACS部署于可持续航空燃料性质预测这一现实任务中,结果表明即使仅使用29个带标签样本,也能训练出准确的预测模型。通过在低数据环境中实现可靠的性质预测,ACS显著拓展了AI驱动材料发现与设计的适用边界,并加速了其实际应用进程。

背景

基于机器学习的分子性质预测模型能够通过提供准确的性质预测,显著加速高性能分子和混合物的从头设计。然而,这类模型的效能高度依赖于训练数据的数量与质量,数据匮乏成为制约其广泛应用的关键瓶颈。为缓解这一问题,多任务学习(Multi-task learning, MTL)被提出用以挖掘相关分子性质之间的关联性。借助归纳迁移(inductive transfer),MTL能够利用一个任务的训练信号或学到的表示来提升其他任务的预测性能,从而挖掘共享结构,实现更准确的预测。

然而在实践中,MTL常常受到负迁移(Negative Transfer, NT)的困扰,即某一任务引导的模型更新反而损害了其他任务的性能。现有研究普遍将其归因于任务间相关性较低以及由此引发的共享参数梯度冲突。这些冲突不仅削弱了多任务学习的优势,有时甚至会导致整体性能的下降。除此之外,网络结构不匹配、优化策略不一致以及数据分布差异(如时序或空间异质性)也可能导致负迁移。这些因素往往相互作用,进一步加剧性能退化。因此,缓解策略必须具有一定的鲁棒性,能够同时应对架构、优化和数据分布上的多重不匹配。在许多现实应用中,MTL还面临严重的任务不平衡问题,这通常指某些任务拥有的数据远少于其他任务。在此情形下,数据稀缺任务对共享模型参数的贡献被削弱,进一步加重负迁移效应。尽管已有方法尝试通过插补或完整样本分析来处理缺失标签,但这些方法要么泛化能力有限,要么无法充分利用现有数据。为此,本文采用损失掩码作为更实用的缺失标签应对策略。

针对上述挑战,作者提出了一种适用于多任务图神经网络的高效训练框架ACS(Adaptive Checkpointing with Specialization)。该方法结合了共享的、任务无关的骨干网络任务特定的可训练头部模块,并在训练过程中动态检测负迁移信号,自适应地保存和回溯模型参数。在训练阶段,各任务共享骨干网络以挖掘共性;而在推理阶段,为每个任务生成一个专属模型,从而在保留归纳迁移优势的同时,有效隔离有害的参数干扰。

方法

模型架构

本框架由共享的GNN骨干独立的MLP预测头组成(图1a),每个预测任务对应一个独立的MLP头。GNN骨干由多个边调节卷积层构成,每一层后接一个非线性激活函数。这些边调节层在节点更新过程中融合了化学键信息,使模型能够同时学习原子级和键级的特征。每个任务对应一个六层MLP,采用级联神经元配置(即隐藏层的维度随着深度逐层递减)。这些MLP预测头在利用GNN骨干生成的共享潜在表示的同时,支持任务专属的学习。

图1 ACS, MTL, MTL-GLC以及STL训练机制

分子图构建

使用PyTorch的Geometric库将SMILES分子字符串转换为图结构编码。每个分子被表示为一个图,其中节点代表原子,边表示化学键。表1汇总了用于节点和边的完整特征集合。节点和边特征分别存储到对应的特征矩阵中。通过将这些表示输入到边调节卷积层中,模型能够在一个任务无关的潜在空间中聚合节点与边之间的信息。最终,每个分子的潜在表示被送入对应的任务专属MLP预测头,用于进行最终的性质预测。

表1 分子图中原子和边特征

自适应检查点机制

自适应检查点机制会持续监控每个任务的验证损失,一旦某任务的验证损失相较于其上一个检查点有所下降,便会创建一个新的检查点。在这个时间点,当前的共享骨干参数该任务对应的专属预测头参数将被保存,从而确保每个任务都能保留其已知的最优参数状态。通过将这些检查点隔离开来,ACS能够在后续训练步骤中缓解来自其他任务的负迁移影响。表2展示了该训练过程的伪代码。

表2 自适应检查点训练机制

缺失值的损失掩码机制

为最大程度利用训练数据,在损失函数中实现了一种掩码机制(图2)。该方法在计算预测值与真实值之间的误差时,仅选择有效的数据点进行参与,忽略缺失的目标值。最终损失会根据目标元素的总数进行归一化,从而在不同规模的数据集之间保持鲁棒性。

图2 损失掩码过程

融入不确定性的损失函数

在SAF性质预测任务中,每个训练样本都包含一个对应的不确定性值σ,该值来源于实验记录或基于测量过程的估算。为了在模型训练中融入这些不确定性信息,对每个样本的平方误差项进行加权缩放,其缩放因子为α/(α+σi)。其中,α是一个用户设定的较小的超参数。在SAF数据集中,作者发现设置α=0.1时能达到较好效果。最终的不确定性感知均方误差(MSE_UA)定义如下公式:

该设计在训练过程中对高不确定性样本的误差惩罚更小,从而防止噪声较大的测量值对参数更新产生过度影响。

数据制备与预处理

作者选取了MoleculeNet基准数据集中的一个子集,这些数据集具有多个下游任务,且适合进行参数化评估。最终选定的三个满足条件的数据集为:ClinTox、SIDER 和 Tox21。基于Murcko骨架划分策略将每个数据集按8:1:1的比例划分为训练集、验证集和测试集。此外,系统性地调整了任务不平衡度。具体而言,在CT_TOX这一列中人为引入缺失值,以实现设定的不平衡程度,并仅保留一部分样本以模拟小规模数据集。通过控制这一变量,可以观察从严重不平衡的低数据情形到相对平衡的数据充足情形下,ACS的表现优势何时最为显著、何时可与传统的多任务学习或单任务学习方法相媲美。

为了验证ACS方法在真实场景下的实用价值,将其应用于一个SAF性质预测任务。具体来说,整理了一个包含1379个分子的数据集,这些分子是各种SAF混合燃料中常见的组分。由于SAF数据集中存在较高比例的缺失标签,用5折交叉验证以获得更稳健的模型性能估计,并适配数据的稀疏性。

结果

ACS有效缓解负面迁移

表3展示了对比结果,ACS的表现与现有模型相当甚至超越。仅有D-MPNN实现了接近的稳定性能。虽然两者都采用了消息传递机制,但D-MPNN沿有向边传播消息,以减少冗余更新。总体而言,ACS相较于其他基于节点中心消息传递的模型,实现了平均11.5%的性能提升。

表3 在三个MoleculeNet基准数据集上的测试表现

为了明确ACS的性能提升是否来自其整体架构设计或其缓解NT的能力,对多种基准训练方案进行了评估(图1)。将ACS与不带checkpointing的多任务学习(MTL)带全局损失检查点的多任务学习(MTL-GLC)、以及带checkpointing的单任务学习(STL)进行了比较,后者为每个任务配备了独立的主干和输出头,因此完全取消了参数共享。STL拥有比MTL方法更高的学习容量,因为它没有引入任何参数共享。ACS依然平均优于STL8.3%,这表明归纳迁移所带来的益处十分显著。MTL和MTL-GLC的性能也超过了STL,但优势较小(分别为3.9% 和5.0%)。ACS与其他MTL方法之间更大的差距,突显了其在抑制负迁移方面的有效性。

机制洞察与适用范围

作者使用ClinTox数据集研究了任务不平衡对MTL性能的影响。为实现对任务不平衡程度I的定量评估,使用以下公式定义每个任务的不平衡度:

其中,Li是数据集D中第i个任务的有标签数据数量。该指标反映了某一任务标签的稀缺程度,相对于数据集中标签最多的任务而言。通过调节该比例,可以直接观察当某些任务标签显著较少时,NT如何发生。

图3 所有任务上Clintox数据集的平均性能差异

图3a和3b展示了ACS在不同任务不平衡程度下相对于MTL-GLC和STL的性能提升。总体来看,ACS在高度不平衡的情境下表现出最显著的性能增益,这也与其通过检查点机制缓解NT、保留任务特异性head的设计初衷一致。由此发现ACS在高任务不平衡的情境下仍能表现出色。这类情境的主要问题是,高数据量任务主导了共享参数的更新,导致梯度冲突,从而削弱了低数据量任务的表现。ACS通过在模型参数受到负迁移或随机梯度下降中噪声干扰之前保留一份模型检查点,有效应对了这一问题。通过这种方式,ACS能够利用更具表现力的网络来服务于数据充分的任务,同时保护数据稀缺任务免受模型复杂度过高带来的负面影响。

真实数据应用

为了评估ACS的实际实用性,将其部署于一个真实世界的应用场景中:预测可持续航空燃料(SAF)的关键性质。针对SAF性质的机器学习预测方法,旨在减少实验测试所需的时间与成本。整理了一个包含1381个SAF常见烃类分子的数据库,涵盖了15项对SAF设计和认证至关重要的性质。根据标注样本数量n将数据划分为两种情境:超低数据情境(n 低数据情境(150 

图4 MTL,MTL-GLC,STL和ACS平均性能

图4所示,ACS在所有任务中表现优于这些基线方法,平均提升达12.9%。此外,ACS在5折交叉验证中的表现更为稳定,相比基线方法,平均将变异系数降低了32.7%。

图5 各种训练机制在SAF数据集的性能

ACS的优势在超低数据情境中表现得尤为显著。结果表明ACS能有效缓解负迁移问题,即使在标注数据极少的任务中,也能实现有益的归纳迁移。在低数据情境下,STL的性能与ACS相当,这表明对于样本数量较多的任务,STL能充分利用其专属学习能力。ACS在两种数据情境中都始终展现出稳健的性能,凸显了其在任务不平衡显著且可用数据有限的现实场景中的有效性。

任务相关性在ACS主干网络中表现明显

作者计算了在每个SAF数据集分子上,不同任务的主干网络输出图嵌入之间的余弦相似度。对所有分子结果进行平均,从而为每对任务生成一个相似度值。图6展示了相似度排名前10%的任务对。部分任务对表现出非常高的对齐程度,说明它们捕捉到了共享的物理机制或化学特征。这些发现证实,ACS能够有效地利用归纳迁移,在保留任务间有意义共性的同时,实现任务专属的优化。

图6 任务相似性分析弦图

图7展示了两组高相似任务对的UMAP映射图,颜色表示分子家族。这些在相关任务中一致的、以化学家族为单位的分子投射结果表明,即使在没有显式的化学家族监督信息的情况下,任务专用的骨干网络也倾向于将化学结构相似的分子嵌入到潜在空间中结构类似的区域中。此外,最后一次验证集损失下降所处的训练轮次也支持了这样一个观点,相似的任务在训练过程中往往在相似的时间段获得了积极的参数更新。尽管任务在标签数量和复杂性方面存在差异,但相关任务通常会在训练的相似阶段趋于收敛,反映出它们共享的归纳学习动态。

图7 任务相似性分析UMAP图

训练和推理的计算损失

以实际的SAF分子性质预测场景为例,MTL的训练速度是STL基线的11.5倍,而可训练参数仅为其6%(表4)。这一效率提升源于特定的模型架构设计,其中约43%的总参数位于GNN主干中,而每个预测头仅占据极少的参数量。在多个任务间共享这个复杂的GNN主干极大地降低了MTL的整体训练成本。然而,若采用较小的主干架构,这种训练速度相较STL的优势可能会减弱。

尽管在ACS中进行了更频繁的检查点保存,并未观察到与MTL-GLC相比的训练时间显著增加。这一结果强调了ACS在缓解负迁移时的计算效率,即使需要运行多个训练实例。检查点保存带来的开销使得ACS和MTL-GLC的训练时间比标准MTL长约17%,这凸显了低开销检查点技术的重要性。

值得注意的是,在材料设计相关的分子性质预测任务中,通常需要同时评估多个性能指标,因此像ACS这样的方法特别适用于真实设计问题。通过共享表示,这类多性质预测方法能够加快化学空间的探索,并推动可行分子候选的发现进程。

表4 各种架构在SAF数据上训练的参数及时间

讨论

研究结果表明,ACS在多任务分子性质预测中提供了一种稳健且数据高效的策略,能够有效应对负迁移问题。通过将共享的GNN骨干网络任务特定的头相结合,并在某一任务的验证损失达到新低时选择性地保留参数,ACS在利用共享表示与防止有害更新干扰个别任务之间实现了有效平衡。在多个分子性质预测基准测试中,ACS的表现与领先的监督学习方法相当甚至更优,进一步验证了解决梯度冲突的重要性。ACS的实际应用价值在其用于SAF性质预测任务中得到了充分体现,在这一场景下,ACS显著减少了实现准确建模所需的训练样本数量。尤其是在超低数据量条件下,仅使用29个有标签样本便能获得可靠的预测结果。

展望未来,后续实验将系统性地探索数据特征的差异任务之间关联程度如何分别增强模型协同效应或加剧负迁移。本研究聚焦于数据稀缺与任务不平衡,未来还应考察其他负迁移诱因,如学习率设置不匹配、数据分布差异等。此外,解决检查点机制带来的计算成本问题(如内存与延迟开销)也将提升其可扩展性,使ACS能够扩展到数百个任务,并适用于资源受限平台上的部署。

目前仍不清楚,在一个完全由不重叠任务组成的数据集上,是否仍存在有意义的跨任务学习潜力。在这种极端条件下评估ACS,有助于进一步判断其检查点策略是否比标准的单任务学习方法具有更大优势。最后,引入任务优先级或调度机制,也可能进一步增强ACS对因任务不平衡而产生的参数影响差异的调控能力。与此同时,未来版本的ACS也可以考虑使用元学习或预训练权重来初始化共享骨干网络,这种混合策略可能在标签稀缺条件下加快收敛速度、提高泛化能力,同时保留ACS在处理任务不平衡和稀疏监督方面的核心优势。

参考链接:

https://doi.org/10.1038/s42004-025-01592-1

--------- End ---------

内容中包含的图片若涉及版权问题,请及时与我们联系删除