掘金 人工智能 前天 19:35
对抗训练:FGM与PGD方法介绍
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入介绍了对抗训练的核心思想和两种常用方法:FGM(Fast Gradient Method)和PGD(Projected Gradient Descent)。文章解释了为何需要对抗训练,即通过引入微小但能欺骗模型的扰动来增强模型对输入变化的鲁棒性。FGM通过一次性沿梯度方向扰动输入来生成对抗样本,而PGD则通过多步迭代扰动,并在局部邻域内搜索更强的对抗样本,以期获得更好的模型性能。文章还探讨了对抗训练的优势,如提升泛化能力、减少过拟合以及提高模型在分布外样本上的稳定性,并指出PGD虽然效果更好但计算成本更高。

🎯 **对抗训练的必要性**:深度学习模型虽然拟合能力强,但在面对输入中的微小扰动(如数据增强或对抗样本)时可能表现不佳。对抗训练通过生成人类难以察觉但能最大化模型错误预测的样本,并将这些样本纳入训练,旨在增强模型的鲁棒性,使其在处理稍有偏差的输入时仍能保持准确性。

⚡ **FGM(Fast Gradient Method)**:FGM是一种快速生成对抗样本的方法。它通过计算损失函数对输入的梯度,然后沿着梯度的单位方向施加一个小的扰动(由超参数ε控制),从而生成一个“对抗样本”。这个过程旨在最大化模型的损失,使模型学习到如何应对这种扰动。在实践中,通常会对梯度进行归一化(如L2范数),以控制扰动的幅度。

🔄 **PGD(Projected Gradient Descent)**:PGD是比FGM更强的对抗训练方法。它在一个局部邻域内进行多步迭代扰动,每一步都试图在当前扰动样本的基础上进一步增加损失。PGD的关键在于其“投影”步骤,它会确保每次迭代产生的扰动不会超出预设的范围(由ε定义),从而更有效地逼近“最坏情况”。PGD通过多次迭代和优化,能够生成更具挑战性的对抗样本,从而训练出更鲁棒的模型。

📈 **PGD的优势与权衡**:PGD相比FGM能够更充分地搜索对抗样本,攻击能力更强,并且有助于模型克服局部最优解,从而获得更好的性能。然而,其多步迭代的特性也带来了显著的计算成本增加,因此对抗训练的次数(K)成为一个需要权衡的关键超参数,需要在模型性能和训练效率之间找到平衡。

🛡️ **对抗训练的深层价值**:对抗训练不仅是提升模型鲁棒性的技术,更是一种有效的正则化方法。它能减少模型过拟合,迫使模型在局部邻域内保持预测的一致性,并提高模型对分布外样本的稳定性。因此,对抗训练是一种提升模型泛化能力、应对现实世界不确定性的重要策略。

对抗训练:FGM与PGD方法介绍

为什么会有对抗训练

当我们对一个事物进行建模或描述时,即使这种描述可能存在偏差,它仍然是对该事物本质的一种映射或表达。对于一张猫的图片来说,我们知道它是猫,但当我们对图片进行裁剪、缩放、遮掩和模糊后(对于模型而言,常见的数据增强(如裁剪、模糊等)是有意识地改变输入,但它们通常不会欺骗模型。而对抗样本则是通过在输入中加入微小且人类难以察觉的扰动,以最大化模型的预测错误,从而检验模型的鲁棒性。这里只是为了方便理解故使用这些),我们可能不太能认出这是猫,但他实际上仍是猫。为了让模型能在这种情况下认出他是猫,所以我们认为构建这种人类难以察觉但会导致模型错误预测的微小扰动样本加入到模型的训练中,能使模型具备辨别的能力,而这种生成困难表示的方式就是对样本进行数据扰动,而在深度学习中,由于深度模型具有极高的拟合能力,它们很容易在训练数据上表现良好,但在面对稍有偏差的输入时却可能表现极差。而训练模型的数据往往本身就是有噪声且不充分的,所以我们需要引入对抗训练,让模型能优化到那些他原本可能薄弱的实例。

数据扰动的几种方式

1. FGM

在FGM中,数据扰动让输入或嵌入在最大化损失的方向上沿梯度单位方向扰动一次(通常要对梯度进行归一化以控制扰动幅度,常见的是使用L2归一,即将梯度除以它的L2范数),使得模型损失最大程度增长,让模型更能学到知识,至于这里的梯度是如何来的,如果你看过使用FGM的代码,你会发现在一个batch中调用了两次model和loss,第一次正是用来计算梯度帮助第二次生成对抗样本的

xadv=x+εxL(x,y)xL(x,y)x_{adv} = x + ε \cdot \frac{\nabla_x L(x, y)}{\|\nabla_x L(x, y)\|}
 class FGM:     def __init__(self, model):         self.model = model         self.backup = {}      def attack(self, epsilon=1.0, emb_name='embed_tokens'): #例子是nlp文本生成的,其中embedding这里会给出每个token的向量,所以对他的参数进行扰动就能实现对样例的扰动         for name, param in self.model.named_parameters():             if param.requires_grad and emb_name in name:                 self.backup[name] = param.data.clone()                 norm = torch.norm(param.grad)                 if norm != 0:                     r_at = epsilon * param.grad / norm                     param.data.add_(r_at)      def restore(self, emb_name='embed_tokens'):         for name, param in self.model.named_parameters():             if param.requires_grad and emb_name in name and name in self.backup:                 param.data = self.backup[name]         self.backup = {}  ''' 在batch中 '''  output=model(**batch) loss=output.loss loss.backward() fgm.attack() #进行数据扰动 output2=model(**batch) loss2=output2.loss loss2.backward() fgm.restore() #恢复扰动数据,因为有些时候的扰动实际上是通过作作用于模型参数做的,如nlp,因为他的原始输入是文字等,无法扰动。所以攻击完后要复原,趁着参数还没更新

2. PGD

与 FGM 只执行一次线性扰动不同,PGD 会在一个局部邻域内多步迭代地进行扰动,每一步都在当前对抗样本基础上进一步最大化损失,并通过投影确保扰动不会超过指定的上限,从而更有效地逼近最坏情况。

第一个样本

x0adv=x+δ0,其中 δ0U(ϵ,ϵ)x_{0}^{adv} = x + δ_0,\quad \text{其中 } δ_0 \sim \mathcal{U}(-\epsilon, \epsilon)

后续样本

xt+1=xt+αsign(xL(xt,y))xt+1=clip(xt+1,xε,x+ε)x_{t+1} = x_t + α * sign(∇_x L(x_t, y))\\ x_{t+1} = clip(x_{t+1}, x-ε, x+ε)

clip 对扰动进行裁剪防止过大或过小

 class PGD:     def __init__(self, model, emb_name='model.encoder.embed_tokens', epsilon=1.0, alpha=0.3, K=3):         self.model = model         self.emb_name = emb_name  # 目标embedding名称,因为例子是nlp文本生成的,其中embedding这里会给出每个token的向量,所以对他的参数进行扰动就能实现对样例的扰动         self.epsilon = epsilon    # 扰动半径         self.alpha = alpha        # 每步步长         self.K = K                # 攻击步数         self.emb_backup = {}      # 原始embedding备份         self.grad_backup = {}      def attack(self, is_first_attack=False):         for name, param in self.model.named_parameters():             if param.requires_grad and self.emb_name in name:                 if is_first_attack:                     self.emb_backup[name] = param.data.clone()                 norm = torch.norm(param.grad)                 if norm != 0 and not torch.isnan(norm):                     r_at = self.alpha * param.grad / norm                     param.data.add_(r_at)                     param.data = self.project(name, param.data)      def restore(self):         for name, param in self.model.named_parameters():             if param.requires_grad and self.emb_name in name:                 assert name in self.emb_backup                 param.data = self.emb_backup[name]         self.emb_backup = {}      def project(self, param_name, param_data):         r = param_data - self.emb_backup[param_name]         if torch.norm(r) > self.epsilon:             r = self.epsilon * r / torch.norm(r)         return self.emb_backup[param_name] + r      def backup_grad(self):         for name, param in self.model.named_parameters():             if param.requires_grad and param.grad is not None:                 self.grad_backup[name] = param.grad.clone()      def restore_grad(self):         for name, param in self.model.named_parameters():             if param.requires_grad and param.grad is not None:                 param.grad = self.grad_backup[name]   ''' 在batch中 ''' outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels']) loss = outputs.loss epoch_loss += loss.item() loss.backward()  pgd.backup_grad() #保留原梯度,用来后续还原 for t in range(pgd.K):     pgd.attack(is_first_attack=(t==0)) #数据扰动     if t != pgd.K - 1:         model.zero_grad() #清除中间梯度防止污染累计梯度     else:         pgd.restore_grad() #最后一次攻击后恢复梯度     outputs_adv = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])     adv_loss = outputs_adv.loss     adv_loss.backward() #求梯度用于优化 pgd.restore() #恢复扰动数据

PGD多次生成对抗样本并优化的优势

    攻击能力更强搜索更充分克服局部最优解

缺点显而易见,太耗时间了,所以对于PGD而言,对抗次数是一个重要的超参数,需要平衡性能与效果

为什么对抗训练有效

因此,对抗训练不仅是一种提升鲁棒性的技术手段,更是一种提升泛化能力、缓解过拟合的正则化策略,尤其在面对现实世界中非理想输入时,它表现出明显优势。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

对抗训练 FGM PGD 深度学习 模型鲁棒性
相关文章