import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
from functools import partial

from dataset import params
from model import get_model


class BaseAttack(object):
    def __init__(self, attack_name, model_name, target):
        self.attack_name = attack_name
        self.model_name = model_name
        self.target = target
        # Untargeted attacks maximise the loss (+1); targeted attacks
        # minimise it towards the target label (-1).
        self.loss_flag = -1 if self.target else 1
        self.used_params = params(self.model_name)

        self.model = get_model(self.model_name)
        self.model.cuda()
        self.model.eval()

    def forward(self, *input):
        raise NotImplementedError

    def _mul_std_add_mean(self, inps):
        # De-normalise back to [0, 1] pixel space. Note: in-place on `inps`.
        dtype = inps.dtype
        mean = torch.as_tensor(self.used_params['mean'], dtype=dtype).cuda()
        std = torch.as_tensor(self.used_params['std'], dtype=dtype).cuda()
        inps.mul_(std[:, None, None]).add_(mean[:, None, None])
        return inps

    def _sub_mean_div_std(self, inps):
        # Normalise with the model's per-channel mean/std (out-of-place).
        dtype = inps.dtype
        mean = torch.as_tensor(self.used_params['mean'], dtype=dtype).cuda()
        std = torch.as_tensor(self.used_params['std'], dtype=dtype).cuda()
        inps = (inps - mean[:, None, None]) / std[:, None, None]
        return inps

    def _save_images(self, inps, filenames, output_dir):
        unnorm_inps = self._mul_std_add_mean(inps)
        for i, filename in enumerate(filenames):
            save_path = os.path.join(output_dir, filename)
            image = unnorm_inps[i].permute([1, 2, 0])
            image = torch.clamp(image, 0, 1)
            image = Image.fromarray((image.detach().cpu().numpy() * 255).astype(np.uint8))
            image.save(save_path)

    def _update_inps(self, inps, grad, step_size):
        unnorm_inps = self._mul_std_add_mean(inps.clone().detach())
        unnorm_inps = unnorm_inps + step_size * grad.sign()
        unnorm_inps = torch.clamp(unnorm_inps, min=0, max=1).detach()
        adv_inps = self._sub_mean_div_std(unnorm_inps)
        return adv_inps

    def _update_perts(self, perts, grad, step_size):
        perts = perts + step_size * grad.sign()
        perts = torch.clamp(perts, -self.epsilon, self.epsilon)
        return perts

    def _return_perts(self, clean_inps, inps):
        clean_unnorm = self._mul_std_add_mean(clean_inps.clone().detach())
        adv_unnorm = self._mul_std_add_mean(inps.clone().detach())
        return adv_unnorm - clean_unnorm

    def __call__(self, *input, **kwargs):
        images = self.forward(*input, **kwargs)
        return images
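
# Illustrative sketch (not part of the original attack code): the
# _sub_mean_div_std / _mul_std_add_mean pair above implement standard
# per-channel (de)normalisation. The self-contained check below uses
# made-up statistics as a stand-in for params(model_name) to show the
# round trip these helpers are expected to satisfy.
def _normalisation_round_trip_check():
    mean = torch.tensor([0.5, 0.5, 0.5])    # assumed stats, stand-in for used_params['mean']
    std = torch.tensor([0.25, 0.25, 0.25])  # assumed stats, stand-in for used_params['std']
    x = torch.rand(2, 3, 224, 224)          # images in [0, 1]
    normed = (x - mean[:, None, None]) / std[:, None, None]
    restored = normed * std[:, None, None] + mean[:, None, None]
    assert torch.allclose(x, restored, atol=1e-6)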

class ADG(BaseAttack):
    def __init__(self, model_name, sample_num_batches=130, steps=10,
                 epsilon=16 / 255, target=False, decay=1.0):
        super(ADG, self).__init__('ADG', model_name, target)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay
        self.image_size = 224
        self.crop_length = 16
        self.sample_num_batches = sample_num_batches
        self.max_num_batches = int((224 / 16) ** 2)
        assert self.sample_num_batches <= self.max_num_batches
        self._register_model()

    def _register_model(self):
        # Backward hooks that rescale the incoming gradient by `gamma` and
        # zero out the rows/columns (or tokens/channels) holding the most
        # extreme gradient entries of the first sample in the batch.
        def attn_ADG(module, grad_in, grad_out, gamma):
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['vit_base_patch16_224', 'visformer_small', 'pit_b_224']:
                B, C, H, W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, C, H * W)
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=1)
                max_all_H = max_all // W
                max_all_W = max_all % W
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=1)
                min_all_H = min_all // W
                min_all_W = min_all % W

                out_grad[:, range(C), max_all_H, :] = 0.0
                out_grad[:, range(C), :, max_all_W] = 0.0
                out_grad[:, range(C), min_all_H, :] = 0.0
                out_grad[:, range(C), :, min_all_W] = 0.0

            if self.model_name in ['cait_s24_224']:
                B, H, W, C = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, H * W, C)
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=0)
                max_all_H = max_all // W
                max_all_W = max_all % W
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=0)
                min_all_H = min_all // W
                min_all_W = min_all % W

                out_grad[:, max_all_H, :, range(C)] = 0.0
                out_grad[:, :, max_all_W, range(C)] = 0.0
                out_grad[:, min_all_H, :, range(C)] = 0.0
                out_grad[:, :, min_all_W, range(C)] = 0.0
            return (out_grad, )

        def attn_cait_ADG(module, grad_in, grad_out, gamma):
            # CaiT class-attention variant: only the class token attends,
            # so extremes are located per channel over the token axis.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            B, H, W, C = grad_in[0].shape
            out_grad_cpu = out_grad.data.clone().cpu().numpy()
            max_all = np.argmax(out_grad_cpu[0, :, 0, :], axis=0)
            min_all = np.argmin(out_grad_cpu[0, :, 0, :], axis=0)

            out_grad[:, max_all, :, range(C)] = 0.0
            out_grad[:, min_all, :, range(C)] = 0.0
            return (out_grad, )

        def q_ADG(module, grad_in, grad_out, gamma):
            # The query projection of CaiT's class-attention blocks gets
            # its gradient zeroed entirely.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            out_grad[:] = 0.0
            return (out_grad, grad_in[1], grad_in[2])

        def v_ADG(module, grad_in, grad_out, gamma):
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['visformer_small']:
                B, C, H, W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, C, H * W)
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=1)
                max_all_H = max_all // W
                max_all_W = max_all % W
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=1)
                min_all_H = min_all // W
                min_all_W = min_all % W
                out_grad[:, range(C), max_all_H, max_all_W] = 0.0
                out_grad[:, range(C), min_all_H, min_all_W] = 0.0
            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224']:
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=0)
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=0)
                out_grad[:, max_all, range(c)] = 0.0
                out_grad[:, min_all, range(c)] = 0.0
            return (out_grad, grad_in[1])

        def mlp_ADG(module, grad_in, grad_out, gamma):
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['visformer_small']:
                B, C, H, W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, C, H * W)
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=1)
                max_all_H = max_all // W
                max_all_W = max_all % W
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=1)
                min_all_H = min_all // W
                min_all_W = min_all % W
                out_grad[:, range(C), max_all_H, max_all_W] = 0.0
                out_grad[:, range(C), min_all_H, min_all_W] = 0.0
            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224', 'resnetv2_101']:
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
                max_all = np.argmax(out_grad_cpu[0, :, :], axis=0)
                min_all = np.argmin(out_grad_cpu[0, :, :], axis=0)
                out_grad[:, max_all, range(c)] = 0.0
                out_grad[:, min_all, range(c)] = 0.0
            # Preserve the remaining gradient inputs unchanged.
            return (out_grad,) + tuple(grad_in[1:])

        attn_ADG_hook = partial(attn_ADG, gamma=0.25)
        attn_cait_ADG_hook = partial(attn_cait_ADG, gamma=0.25)
        v_ADG_hook = partial(v_ADG, gamma=0.75)
        q_ADG_hook = partial(q_ADG, gamma=0.75)
        mlp_ADG_hook = partial(mlp_ADG, gamma=0.5)

        # register_backward_hook is deprecated in recent PyTorch releases
        # but is kept here because full backward hooks behave differently
        # on these modules.
        if self.model_name in ['vit_base_patch16_224', 'deit_base_distilled_patch16_224']:
            for i in range(12):
                self.model.blocks[i].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                self.model.blocks[i].attn.qkv.register_backward_hook(v_ADG_hook)
                self.model.blocks[i].mlp.register_backward_hook(mlp_ADG_hook)
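        # PiT-B stacks 13 transformer blocks across three stages:
        # blocks 0-2 live in transformers[0], 3-8 in transformers[1],
        # and 9-12 in transformers[2], hence the index arithmetic below.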
        elif self.model_name == 'pit_b_224':
            for block_ind in range(13):
                if block_ind < 3:
                    transformer_ind = 0
                    used_block_ind = block_ind
                elif 3 <= block_ind < 9:
                    transformer_ind = 1
                    used_block_ind = block_ind - 3
                else:
                    transformer_ind = 2
                    used_block_ind = block_ind - 9
                block = self.model.transformers[transformer_ind].blocks[used_block_ind]
                block.attn.attn_drop.register_backward_hook(attn_ADG_hook)
                block.attn.qkv.register_backward_hook(v_ADG_hook)
                block.mlp.register_backward_hook(mlp_ADG_hook)
        elif self.model_name == 'cait_s24_224':
            for block_ind in range(26):
                if block_ind < 24:
                    self.model.blocks[block_ind].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.blocks[block_ind].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.blocks[block_ind].mlp.register_backward_hook(mlp_ADG_hook)
                else:
                    # Indices 24 and 25 are the two class-attention blocks.
                    token_block = self.model.blocks_token_only[block_ind - 24]
                    token_block.attn.attn_drop.register_backward_hook(attn_cait_ADG_hook)
                    token_block.attn.q.register_backward_hook(q_ADG_hook)
                    token_block.attn.k.register_backward_hook(v_ADG_hook)
                    token_block.attn.v.register_backward_hook(v_ADG_hook)
                    token_block.mlp.register_backward_hook(mlp_ADG_hook)
        elif self.model_name == 'visformer_small':
            for block_ind in range(8):
                if block_ind < 4:
                    self.model.stage2[block_ind].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.stage2[block_ind].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.stage2[block_ind].mlp.register_backward_hook(mlp_ADG_hook)
                else:
                    self.model.stage3[block_ind - 4].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.stage3[block_ind - 4].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.stage3[block_ind - 4].mlp.register_backward_hook(mlp_ADG_hook)

    def _generate_samples_for_interactions(self, perts, seed):
        # Keep the perturbation only inside `sample_num_batches` randomly
        # chosen crop_length x crop_length patches of the image grid.
        add_noise_mask = torch.zeros_like(perts)
        grid_num_axis = int(self.image_size / self.crop_length)

        ids = list(range(self.max_num_batches))
        random.seed(seed)
        random.shuffle(ids)
        ids = np.array(ids[:self.sample_num_batches])

        rows, cols = ids // grid_num_axis, ids % grid_num_axis
        for r, c in zip(rows, cols):
            add_noise_mask[:, :, r * self.crop_length:(r + 1) * self.crop_length,
                           c * self.crop_length:(c + 1) * self.crop_length] = 1
        add_perturbation = perts * add_noise_mask
        return add_perturbation

    def forward(self, inps, labels):
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()
        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)  # in-place: inps is de-normalised
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()
        for i in range(self.steps):
            outputs = self.model(self._sub_mean_div_std(unnorm_inps + perts))
            cost = self.loss_flag * loss(outputs, labels)
            cost.backward()
            # Momentum update (MI-FGSM style): normalise by the mean
            # absolute gradient, then accumulate with decay.
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad + momentum * self.decay
            momentum = grad
            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep the adversarial image inside the valid [0, 1] pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
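
# Hedged usage sketch (not in the original file): how the attack class is
# presumably driven. Real batches come from this repo's dataset helpers and
# must already be normalised with the model's mean/std; the random tensor
# below is only a stand-in so the control flow can be exercised.
if __name__ == '__main__':
    attack = ADG('vit_base_patch16_224', steps=10, epsilon=16 / 255)
    dummy_inps = torch.rand(1, 3, 224, 224).cuda()    # stand-in for a normalised batch
    dummy_labels = torch.zeros(1, dtype=torch.long).cuda()
    adv_inps, _ = attack(dummy_inps, dummy_labels)    # __call__ dispatches to forward()
    print(adv_inps.shape)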