# xiaolin_code/methods.py  (293 lines, 13 KiB, Python)

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
import scipy.stats as st
import copy
from utils import ROOT_PATH
from functools import partial
import copy
import pickle as pkl
from torch.autograd import Variable
import torch.nn.functional as F
from dataset import params
from model import get_model
class BaseAttack(object):
    """Common scaffolding for gradient-based adversarial attacks.

    Loads the surrogate model named ``model_name`` (via ``get_model``), moves
    it to the GPU in eval mode, and provides helpers for converting between
    the model's normalized input space and raw ``[0, 1]`` image space.

    Args:
        attack_name: Identifier of the concrete attack (e.g. ``'ADG'``).
        model_name: Key understood by ``params`` / ``get_model``.
        target: If True, run a targeted attack; ``loss_flag`` is set to ``-1``
            so gradient steps *descend* the loss toward the target label
            instead of ascending it.
    """

    def __init__(self, attack_name, model_name, target):
        self.attack_name = attack_name
        self.model_name = model_name
        self.target = target
        # Targeted attacks minimize the loss w.r.t. the (target) labels;
        # untargeted attacks maximize it w.r.t. the true labels.
        self.loss_flag = -1 if self.target else 1
        self.used_params = params(self.model_name)
        self.model = get_model(self.model_name)
        self.model.cuda()
        self.model.eval()

    def forward(self, *inputs):
        """Run the attack; concrete attacks must implement this."""
        raise NotImplementedError

    def _mul_std_add_mean(self, inps):
        """Map normalized inputs back to raw [0, 1] image space.

        Fixed: the original used in-place ``mul_``/``add_`` and therefore
        silently clobbered the caller's tensor; this version is out-of-place,
        matching ``_sub_mean_div_std``. All callers in this file use only the
        return value, so results are unchanged.
        """
        dtype = inps.dtype
        mean = torch.as_tensor(self.used_params['mean'], dtype=dtype).cuda()
        std = torch.as_tensor(self.used_params['std'], dtype=dtype).cuda()
        # Broadcast per-channel stats over (..., C, H, W).
        return inps * std[:, None, None] + mean[:, None, None]

    def _sub_mean_div_std(self, inps):
        """Map raw [0, 1] images into the model's normalized input space."""
        dtype = inps.dtype
        mean = torch.as_tensor(self.used_params['mean'], dtype=dtype).cuda()
        std = torch.as_tensor(self.used_params['std'], dtype=dtype).cuda()
        return (inps - mean[:, None, None]) / std[:, None, None]

    def _save_images(self, inps, filenames, output_dir):
        """Denormalize ``inps`` and write one PNG/JPEG per entry of ``filenames``."""
        unnorm_inps = self._mul_std_add_mean(inps)
        for i, filename in enumerate(filenames):
            save_path = os.path.join(output_dir, filename)
            # CHW -> HWC, clamp into the valid range, then to 8-bit pixels.
            image = unnorm_inps[i].permute([1, 2, 0]).clamp(0, 1)
            image = Image.fromarray((image.detach().cpu().numpy() * 255).astype(np.uint8))
            image.save(save_path)

    def _update_inps(self, inps, grad, step_size):
        """One signed-gradient step in image space; returns re-normalized advs."""
        unnorm_inps = self._mul_std_add_mean(inps.clone().detach())
        unnorm_inps = (unnorm_inps + step_size * grad.sign()).clamp(min=0, max=1).detach()
        return self._sub_mean_div_std(unnorm_inps)

    def _update_perts(self, perts, grad, step_size):
        """Signed-gradient step on the perturbation, clipped to the L-inf ball.

        NOTE: relies on ``self.epsilon`` being set by the subclass.
        """
        perts = perts + step_size * grad.sign()
        return torch.clamp(perts, -self.epsilon, self.epsilon)

    def _return_perts(self, clean_inps, inps):
        """Perturbation in raw image space: adversarial minus clean."""
        clean_unnorm = self._mul_std_add_mean(clean_inps.clone().detach())
        adv_unnorm = self._mul_std_add_mean(inps.clone().detach())
        return adv_unnorm - clean_unnorm

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)
class ADG(BaseAttack):
    """Gradient-dampening transferable attack for transformer surrogates.

    ``_register_model`` installs backward hooks on the attention-dropout,
    q/qkv/v-projection and MLP submodules of the surrogate: each hook rescales
    the incoming gradient by a fixed ``gamma`` and zeroes the entries at the
    per-channel extreme (argmax/argmin) positions. ``forward`` then runs a
    momentum iterative sign-gradient loop under an L-inf budget.

    NOTE(review): the scheme closely resembles Token Gradient Regularization
    style attacks — confirm against the originating paper/repository.
    """

    def __init__(self, model_name, sample_num_batches=130, steps=10, epsilon=16/255, target=False, decay=1.0):
        """
        Args:
            model_name: surrogate key understood by ``get_model``/``params``.
            sample_num_batches: number of 16x16 patches selected by
                ``_generate_samples_for_interactions`` (must be <= 196).
            steps: number of attack iterations.
            epsilon: L-inf budget in raw [0, 1] image space.
            target: targeted-attack flag (flips the loss sign in BaseAttack).
            decay: momentum decay factor for the gradient accumulation.
        """
        super(ADG, self).__init__('ADG', model_name, target)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224
        self.crop_length = 16
        self.sample_num_batches = sample_num_batches
        # 224/16 = 14 patches per axis -> 196 grid cells in total.
        self.max_num_batches = int((224/16)**2)
        assert self.sample_num_batches <= self.max_num_batches
        self._register_model()

    def _register_model(self):
        """Install the gamma-scaling / extreme-entry-zeroing backward hooks.

        NOTE(review): uses the deprecated ``register_backward_hook`` API; the
        ordering/content of ``grad_in`` is module- and version-dependent, so
        keep the PyTorch version pinned (``register_full_backward_hook`` is
        the modern replacement).
        """
        def attn_ADG(module, grad_in, grad_out, gamma):
            # Rescale the gradient entering the attention dropout by gamma.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['vit_base_patch16_224', 'visformer_small', 'pit_b_224']:
                # Gradient laid out as (B, C, H, W): find, per channel, the
                # flat argmax/argmin on sample 0 only, then zero the entire
                # row and column through those positions for every sample.
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H  # NOTE(review): divides/mods by H in both; assumes H == W
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,:] = 0.0
                out_grad[:,range(C),:,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,:] = 0.0
                out_grad[:,range(C),:,min_all_W] = 0.0
            if self.model_name in ['cait_s24_224']:
                # Same idea, but the CaiT gradient arrives as (B, H, W, C).
                B,H,W,C = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, H*W, C)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                max_all_H = max_all//H
                max_all_W = max_all%H
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,max_all_H,:,range(C)] = 0.0
                out_grad[:,:,max_all_W,range(C)] = 0.0
                out_grad[:,min_all_H,:,range(C)] = 0.0
                out_grad[:,:,min_all_W,range(C)] = 0.0
            return (out_grad, )

        def attn_cait_ADG(module, grad_in, grad_out, gamma):
            # CaiT class-attention variant: only index 0 of the third axis is
            # inspected (sample 0); zero the tokens holding the extremes.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            B,H,W,C = grad_in[0].shape
            out_grad_cpu = out_grad.data.clone().cpu().numpy()
            max_all = np.argmax(out_grad_cpu[0,:,0,:], axis = 0)
            min_all = np.argmin(out_grad_cpu[0,:,0,:], axis = 0)
            out_grad[:,max_all,:,range(C)] = 0.0
            out_grad[:,min_all,:,range(C)] = 0.0
            return (out_grad, )

        def q_ADG(module, grad_in, grad_out, gamma):
            # Block the gradient through the class-token query entirely: the
            # scaled gradient is overwritten with zeros; the remaining
            # grad_in entries are passed through unchanged.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            out_grad[:] = 0.0
            return (out_grad, grad_in[1], grad_in[2])

        def v_ADG(module, grad_in, grad_out, gamma):
            # Rescale the gradient entering the value/qkv projection and zero
            # the single entries at the per-channel extreme positions.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['visformer_small']:
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H  # NOTE(review): assumes H == W
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,min_all_W] = 0.0
            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224']:
                # Token-major layout (B, tokens, C): zero one token per channel.
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                out_grad[:,max_all,range(c)] = 0.0
                out_grad[:,min_all,range(c)] = 0.0
            return (out_grad, grad_in[1])

        def mlp_ADG(module, grad_in, grad_out, gamma):
            # Rescale the MLP-block gradient and zero the entries at the
            # per-channel extreme positions.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['visformer_small']:
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H  # NOTE(review): assumes H == W
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,min_all_W] = 0.0
            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224', 'resnetv2_101']:
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                out_grad[:,max_all,range(c)] = 0.0
                out_grad[:,min_all,range(c)] = 0.0
            # Rebuild the full grad_in tuple with only entry 0 replaced.
            for i in range(len(grad_in)):
                if i == 0:
                    return_dics = (out_grad,)
                else:
                    return_dics = return_dics + (grad_in[i],)
            return return_dics

        # Bind the per-module gradient scaling factors.
        attn_ADG_hook = partial(attn_ADG, gamma=0.25)
        attn_cait_ADG_hook = partial(attn_cait_ADG, gamma=0.25)
        v_ADG_hook = partial(v_ADG, gamma=0.75)
        q_ADG_hook = partial(q_ADG, gamma=0.75)
        mlp_ADG_hook = partial(mlp_ADG, gamma=0.5)

        if self.model_name in ['vit_base_patch16_224' ,'deit_base_distilled_patch16_224']:
            # 12 encoder blocks: hook attention-dropout, qkv and mlp of each.
            # NOTE(review): for deit the inner model_name checks in the hooks
            # never match, so its hooks only rescale by gamma (no zeroing).
            for i in range(12):
                self.model.blocks[i].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                self.model.blocks[i].attn.qkv.register_backward_hook(v_ADG_hook)
                self.model.blocks[i].mlp.register_backward_hook(mlp_ADG_hook)
        elif self.model_name == 'pit_b_224':
            # PiT: 13 blocks split 3/6/4 across three transformer stages.
            for block_ind in range(13):
                if block_ind < 3:
                    transformer_ind = 0
                    used_block_ind = block_ind
                elif block_ind < 9 and block_ind >= 3:
                    transformer_ind = 1
                    used_block_ind = block_ind - 3
                elif block_ind < 13 and block_ind >= 9:
                    transformer_ind = 2
                    used_block_ind = block_ind - 9
                self.model.transformers[transformer_ind].blocks[used_block_ind].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                self.model.transformers[transformer_ind].blocks[used_block_ind].attn.qkv.register_backward_hook(v_ADG_hook)
                self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_ADG_hook)
        elif self.model_name == 'cait_s24_224':
            # CaiT: 24 self-attention blocks plus class-attention blocks.
            for block_ind in range(26):
                if block_ind < 24:
                    self.model.blocks[block_ind].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.blocks[block_ind].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.blocks[block_ind].mlp.register_backward_hook(mlp_ADG_hook)
                elif block_ind > 24:
                    # NOTE(review): block_ind == 24 matches neither branch, so
                    # blocks_token_only[0] is never hooked (only index 1 is) —
                    # looks like an off-by-one; confirm against the reference
                    # implementation before changing.
                    self.model.blocks_token_only[block_ind-24].attn.attn_drop.register_backward_hook(attn_cait_ADG_hook)
                    self.model.blocks_token_only[block_ind-24].attn.q.register_backward_hook(q_ADG_hook)
                    self.model.blocks_token_only[block_ind-24].attn.k.register_backward_hook(v_ADG_hook)
                    self.model.blocks_token_only[block_ind-24].attn.v.register_backward_hook(v_ADG_hook)
                    self.model.blocks_token_only[block_ind-24].mlp.register_backward_hook(mlp_ADG_hook)
        elif self.model_name == 'visformer_small':
            # Visformer: 4 attention blocks in stage2 and 4 in stage3.
            for block_ind in range(8):
                if block_ind < 4:
                    self.model.stage2[block_ind].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.stage2[block_ind].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.stage2[block_ind].mlp.register_backward_hook(mlp_ADG_hook)
                elif block_ind >=4:
                    self.model.stage3[block_ind-4].attn.attn_drop.register_backward_hook(attn_ADG_hook)
                    self.model.stage3[block_ind-4].attn.qkv.register_backward_hook(v_ADG_hook)
                    self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_ADG_hook)

    def _generate_samples_for_interactions(self, perts, seed):
        """Mask ``perts`` down to a random subset of 16x16 patches.

        Selects ``sample_num_batches`` of the 196 grid cells (deterministic
        for a given ``seed``) and zeroes the perturbation everywhere else.
        NOTE(review): defined but not called from ``forward`` in this file —
        presumably used by a subclass or caller elsewhere.
        """
        add_noise_mask = torch.zeros_like(perts)
        grid_num_axis = int(self.image_size/self.crop_length)
        ids = [i for i in range(self.max_num_batches)]
        random.seed(seed)
        random.shuffle(ids)
        ids = np.array(ids[:self.sample_num_batches])
        # Flat cell index -> (row, col) on the 14x14 patch grid.
        rows, cols = ids // grid_num_axis, ids % grid_num_axis
        flag = 0  # NOTE(review): unused
        for r, c in zip(rows, cols):
            add_noise_mask[:,:,r*self.crop_length:(r+1)*self.crop_length,c*self.crop_length:(c+1)*self.crop_length] = 1
        add_perturbation = perts * add_noise_mask
        return add_perturbation

    def forward(self, inps, labels):
        """Run the iterative attack.

        Args:
            inps: normalized input batch; assumed (B, 3, 224, 224) — TODO
                confirm with the caller.
            labels: true class indices (or target indices when ``self.target``).

        Returns:
            Tuple of (normalized adversarial batch, None).
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()
        momentum = torch.zeros_like(inps).cuda()
        # Work in raw [0, 1] image space so the L-inf budget is meaningful.
        unnorm_inps = self._mul_std_add_mean(inps)
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()
        for i in range(self.steps):
            # Re-normalize before the model; the registered backward hooks
            # rewrite the gradients during cost.backward().
            outputs = self.model((self._sub_mean_div_std(unnorm_inps + perts)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            # Momentum accumulation: normalize the gradient by its mean
            # absolute value, then add the decayed running momentum.
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)
            grad += momentum*self.decay
            momentum = grad
            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep the adversarial image itself inside [0, 1].
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None