CW2算法源码解析

CW2算法源码解析该公式左半部分含义是尽量降低对抗样本与原始图像的 L2 距离 右半部分是为了达到误分类的目的

大家好,欢迎来到IT知识分享网。


解析

CW这个名字来源于论文的两个作者的名字Carlini和Wagner的首字母,CW2是论文中的L2范式的攻击方法,这是一种基于优化的攻击方法,目标函数为: m i n i m i z e ∥ 1 2 ( t a n h ( w ) + 1 ) − x ∥ 2 2 + c ⋅ f ( 1 2 ( t a n h ( w ) + 1 ) ) minimize \Vert\frac{1}{2}(tanh(w)+1)-x\Vert^2_2+c\cdot f(\frac{1}{2}(tanh(w)+1)) minimize21(tanh(w)+1)x22+cf(21(tanh(w)+1))其中 t a n h tanh tanh表示双曲正切函数,本方法使用 1 2 ( t a n h ( w ) + 1 ) \frac{1}{2}(tanh(w)+1) 21(tanh(w)+1)来表示增加扰动后的图像,这样做可以保证扰动后的图像的范围在 [ 0 , 1 ] [0,1] [0,1]中,同时其处处可导数,有利于优化。在代码中,该方法先将图像输入 t a n h tanh tanh的反函数 a t a n h atanh atanh得到初始的 w w w,也就是优化的起点。文中也提到可以选择多个与初始起点相近随机起点来避免陷入局部最小值。公式中的 c c c表示 f f f的值在目标函数中的权重,文中使用二分法来确定在成功产生对抗样本的情况下的最小的 c c c的值。该公式左半部分含义是尽量降低对抗样本与原始图像的L2距离,右半部分是为了达到误分类的目的。
f f f的定义为: f ( x ′ ) = m a x ( m a x { Z ( x ′ ) i : i ≠ t } − Z ( x ′ ) t , − κ ) f(x’)=max(max\{Z(x’)_i:i\neq t\} -Z(x’)_t, – \kappa) f(x)=max(max{
Z(x)i:
i=t}Z(x)t,κ)

只有当 f ( x ′ ) ≤ 0 f(x’)\le0 f(x)0时,才说明 x ′ x’ x被误分类为了目标标签 t t t Z ( x ′ ) Z(x’) Z(x)表示 x ′ x’ x在模型中的输出值logits。通过改变 κ \kappa κ的值,可以控制对抗样本 x ′ x’ x对于目标标签的置信度, κ \kappa κ越大,则最终的置信度越高。

源码

相关代码的解析我写在了注释中。

 import torch import torch.nn as nn import torch.optim as optim from ..attack import Attack class CW(Attack): r""" CW in the paper 'Towards Evaluating the Robustness of Neural Networks' [https://arxiv.org/abs/1608.04644] Distance Measure : L2 Arguments: model (nn.Module): model to attack. c (float): c in the paper. parameter for box-constraint. (Default: 1) :math:`minimize \Vert\frac{1}{2}(tanh(w)+1)-x\Vert^2_2+c\cdot f(\frac{1}{2}(tanh(w)+1))` kappa (float): kappa (also written as 'confidence') in the paper. (Default: 0) :math:`f(x')=max(max\{Z(x')_i:i\neq t\} -Z(x')_t, - \kappa)` steps (int): number of steps. (Default: 50) lr (float): learning rate of the Adam optimizer. (Default: 0.01) .. warning:: With default c, you can't easily get adversarial images. Set higher c like 1. Shape: - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1]. - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`. - output: :math:`(N, C, H, W)`. Examples:: >>> attack = torchattacks.CW(model, c=1, kappa=0, steps=50, lr=0.01) >>> adv_images = attack(images, labels) .. note:: Binary search for c is NOT IMPLEMENTED methods in the paper due to time consuming. """ def __init__(self, model, c=1, kappa=0, steps=50, lr=0.01): super().__init__("CW", model) self.c = c # 由于二分搜索c的值时间开销太大,所以该代码直接定义了c的值 self.kappa = kappa # 定义公式中的kappa的值 self.steps = steps # 迭代次数 self.lr = lr # 学习率,用于优化器优化w的值 # default为无目标攻击,targeted为有目标攻击,论文中的L2攻击是有目标攻击 # 该代码也可以进行无目标攻击 self.supported_mode = ['default', 'targeted'] def forward(self, images, labels): r""" Overridden. """ self._check_inputs(images) images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) if self.targeted: # 得到目标标签 target_labels = self.get_target_label(images, labels) # w = torch.zeros_like(images).detach() # Requires 2x times w = self.inverse_tanh_space(images).detach() # 通过atanh得到初始w w.requires_grad = True # 初始化最佳对抗样本以及最佳L2距离 best_adv_images = images.clone().detach() best_L2 = 1e10*torch.ones((len(images))).to(self.device) prev_cost = 1e10 dim = len(images.shape) # 该损失函数用来计算公式中的第二范式距离的平方 MSELoss = nn.MSELoss(reduction='none') Flatten = nn.Flatten() # 使用Adam优化器 optimizer = optim.Adam([w], lr=self.lr) for step in range(self.steps): # Get adversarial images adv_images = self.tanh_space(w) # 计算L2距离,也就是公式左半部分 current_L2 = MSELoss(Flatten(adv_images), Flatten(images)).sum(dim=1) L2_loss = current_L2.sum() outputs = self.get_logits(adv_images) # 计算f的值,也就是公式右半部分 if self.targeted: f_loss = self.f(outputs, target_labels).sum() else: f_loss = self.f(outputs, labels).sum() # cost即为公式所得值 cost = L2_loss + self.c*f_loss # 使用Adam优化器进行优化 optimizer.zero_grad() cost.backward() optimizer.step() # Update adversarial images pre = torch.argmax(outputs.detach(), 1) if self.targeted: # 找到成功pre == target_labels的样本 condition = (pre == target_labels).float() else: # 如果是无目标攻击,则找到成功误分类的样本 condition = (pre != labels).float() # 找到损失下降并且condition为1的样本,  # 也就是说只有同时成功误分类并且损失比以前最好的损失还低的图片才会被留下 mask = condition*(best_L2 > current_L2.detach()) best_L2 = mask*current_L2.detach() + (1-mask)*best_L2 mask = mask.view([-1]+[1]*(dim-1)) best_adv_images = mask*adv_images.detach() + (1-mask)*best_adv_images # 如果损失不再下降,就提前停止 # max(.,1)为了防止除数为0 if step % max(self.steps//10,1) == 0: if cost.item() > prev_cost: return best_adv_images prev_cost = cost.item() return best_adv_images # 计算tanh def tanh_space(self, x): return 1/2*(torch.tanh(x) + 1) # 计算atanh def inverse_tanh_space(self, x): # torch.atanh is only for torch >= 1.7.0 # atanh is defined in the range -1 to 1 return self.atanh(torch.clamp(x*2-1, min=-1, max=1)) def atanh(self, x): return 0.5*torch.log((1+x)/(1-x)) # f函数 def f(self, outputs, labels): one_hot_labels = torch.eye(outputs.shape[1]).to(self.device)[labels] other = torch.max((1-one_hot_labels)*outputs, dim=1)[0] # 得到除目标标签外的最高的logit real = torch.max(one_hot_labels*outputs, dim=1)[0] # 得到目标标签的logit if self.targeted: return torch.clamp((other-real), min=-self.kappa) else: # 如果是无目标攻击,那么应该增加到真实标签的距离,所以与上式相反 return torch.clamp((real-other), min=-self.kappa) 

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/158246.html

(0)
上一篇 2025-01-25 16:05
下一篇 2025-01-25 16:10

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信