fujie_code/train_patch.py

"""
Training code for Adversarial patch training
"""
import PIL
from torch.utils.tensorboard import SummaryWriter
# import load_data
from tqdm import tqdm
from load_data import *  # star import (also pulls in torch, optim, F, ...); might this cause repeated imports?
import gc
import matplotlib.pyplot as plt
from torch import autograd
from torchvision import transforms
import subprocess
import patch_config
import sys
import time
from yolo import YOLO


class PatchTrainer(object):
    def __init__(self, mode):
        self.config = patch_config.patch_configs[mode]()  # get the config class for this mode
        # self.darknet_model = Darknet(self.config.cfgfile)  # load the YOLO model
        # self.darknet_model.load_weights(self.config.weightfile)  # default YOLOv2 MS COCO weights; the person class index is 0
        self.darknet_model = YOLO().net
        self.darknet_model = self.darknet_model.eval().cuda()  # TODO: Why eval?
        self.patch_applier = PatchApplier().cuda()  # applies the adversarial patch to images
        self.patch_transformer = PatchTransformer().cuda()  # scales the patch to the target size and adds jitter
        # self.prob_extractor = MaxProbExtractor(0, 80, self.config).cuda()  # extracts the maximum class probability
        self.prob_extractor = MaxProbExtractor(0, 1, self.config).cuda()  # extracts the maximum class probability
        self.nps_calculator = NPSCalculator(self.config.printfile, self.config.patch_size).cuda()  # non-printability score
        self.total_variation = TotalVariation().cuda()  # total variation of the patch
        self.writer = self.init_tensorboard(mode)

    def init_tensorboard(self, name=None):
        subprocess.Popen(['tensorboard', '--logdir=runs'])
        if name is not None:
            time_str = time.strftime("%Y%m%d-%H%M%S")
            return SummaryWriter(f'runs/{time_str}_{name}')
        else:
            return SummaryWriter()

    def train(self):
        """
        Optimize a patch to generate an adversarial example.
        :return: Nothing
        """
        img_size = self.darknet_model.height  # 416
        # print('batch_size:', batch_size)
        batch_size = self.config.batch_size  # 8
        n_epochs = 200
        # n_epochs = 5
        # max_lab = 20  # maximum number of labels per image
        max_lab = 8
        time_str = time.strftime("%Y%m%d-%H%M%S")

        # Generate starting point
        # adv_patch_cpu = self.generate_patch("gray")  # initialize a gray patch filled with 0.5
        adv_patch_cpu = self.read_image("saved_patches/patchnew0.jpg")
        adv_patch_cpu.requires_grad_(True)

        train_loader = torch.utils.data.DataLoader(
            InriaDataset(self.config.img_dir, self.config.lab_dir, max_lab, img_size,
                         shuffle=True),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0)  # does this, combined with `from load_data import *`, cause repeated imports?
        self.epoch_length = len(train_loader)
        print(f'One epoch is {len(train_loader)}')

        optimizer = optim.Adam([adv_patch_cpu], lr=self.config.start_learning_rate, amsgrad=True)  # only the patch is optimized
        scheduler = self.config.scheduler_factory(optimizer)  # AMSGrad above is the Adam fix from the ICLR 2018 best paper

        et0 = time.time()
        for epoch in range(n_epochs):
            ep_det_loss = 0
            ep_nps_loss = 0
            ep_tv_loss = 0
            ep_loss = 0
            bt0 = time.time()
            for i_batch, (img_batch, lab_batch) in tqdm(enumerate(train_loader), desc=f'Running epoch {epoch}',
                                                        total=self.epoch_length):
                with autograd.detect_anomaly():  # 1. prints the forward-op stack that caused a backward failure; 2. raises when the backward pass produces NaN
                    img_batch = img_batch.cuda()  # 8, 3, 416, 416
                    lab_batch = lab_batch.cuda()  # 8, 14, 5; why are the person labels padded to 14?
                    # print('TRAINING EPOCH %i, BATCH %i' % (epoch, i_batch))
                    adv_patch = adv_patch_cpu.cuda()  # 3, 300, 300
                    adv_batch_t = self.patch_transformer(adv_patch, lab_batch, img_size, do_rotate=True, rand_loc=False)
                    p_img_batch = self.patch_applier(img_batch, adv_batch_t)
                    p_img_batch = F.interpolate(p_img_batch,
                                                (self.darknet_model.height, self.darknet_model.width))  # make sure it matches the input image size
                    # print('++++++++++++p_img_batch:+++++++++++++', p_img_batch.shape)

                    img = p_img_batch[1, :, :, :]
                    img = transforms.ToPILImage()(img.detach().cpu())
                    # img.show()

                    outputs = self.darknet_model(p_img_batch)  # input 8x3x416x416; output e.g. 8x425x13x13, where 425 = 5*(5+80)
                    max_prob = 0
                    nps = 0
                    tv = 0
                    for l in range(len(outputs)):  # accumulate over the output feature maps at the three resolutions
                        output = outputs[l]
                        max_prob += self.prob_extractor(output)
                        nps += self.nps_calculator(adv_patch)
                        tv += self.total_variation(adv_patch)

                    nps_loss = nps * 0.01
                    tv_loss = tv * 2.5
                    det_loss = torch.mean(max_prob)  # use the person confidence as the detection loss
                    loss = det_loss + nps_loss + torch.max(tv_loss, torch.tensor(0.1).cuda())

                    ep_det_loss += det_loss.detach().cpu().numpy()
                    ep_nps_loss += nps_loss.detach().cpu().numpy()
                    ep_tv_loss += tv_loss.detach().cpu().numpy()
                    ep_loss += loss.detach().cpu().numpy()  # detach so the running total does not keep the graph alive
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    adv_patch_cpu.data.clamp_(0, 1)  # keep patch in image range

                    bt1 = time.time()
                    if i_batch % 5 == 0:
                        iteration = self.epoch_length * epoch + i_batch
                        self.writer.add_scalar('total_loss', loss.detach().cpu().numpy(), iteration)
                        self.writer.add_scalar('loss/det_loss', det_loss.detach().cpu().numpy(), iteration)
                        self.writer.add_scalar('loss/nps_loss', nps_loss.detach().cpu().numpy(), iteration)
                        self.writer.add_scalar('loss/tv_loss', tv_loss.detach().cpu().numpy(), iteration)
                        self.writer.add_scalar('misc/epoch', epoch, iteration)
                        self.writer.add_scalar('misc/learning_rate', optimizer.param_groups[0]["lr"], iteration)
                        self.writer.add_image('patch', adv_patch_cpu, iteration)
                    if i_batch + 1 >= len(train_loader):
                        print('\n')
                    else:
                        del adv_batch_t, output, max_prob, det_loss, p_img_batch, nps_loss, tv_loss, loss
                        torch.cuda.empty_cache()
                    bt0 = time.time()
            et1 = time.time()
            ep_det_loss = ep_det_loss / len(train_loader)
            ep_nps_loss = ep_nps_loss / len(train_loader)
            ep_tv_loss = ep_tv_loss / len(train_loader)
            ep_loss = ep_loss / len(train_loader)

            # im = transforms.ToPILImage('RGB')(adv_patch_cpu)
            # plt.imshow(im)
            # plt.savefig(f'pics/{time_str}_{self.config.patch_name}_{epoch}.png')

            scheduler.step(ep_loss)
            if True:
                print('  EPOCH NR: ', epoch)
                print('EPOCH LOSS: ', ep_loss)
                print('  DET LOSS: ', ep_det_loss)
                print('  NPS LOSS: ', ep_nps_loss)
                print('   TV LOSS: ', ep_tv_loss)
                print('EPOCH TIME: ', et1 - et0)
                # im = transforms.ToPILImage('RGB')(adv_patch_cpu)
                # plt.imshow(im)
                # plt.show()
                # im.save("saved_patches/patchnew1.jpg")
                im = transforms.ToPILImage('RGB')(adv_patch_cpu)
                if epoch >= 3:
                    im.save(f"saved_patches/patchnew1_t1_{epoch}_{time_str}.jpg")
                del adv_batch_t, output, max_prob, det_loss, p_img_batch, nps_loss, tv_loss, loss
                torch.cuda.empty_cache()
            et0 = time.time()

    def generate_patch(self, type):
        """
        Generate a random patch as a starting point for optimization.

        :param type: Can be 'gray' or 'random'. Whether to generate a gray or a random patch.
        :return:
        """
        if type == 'gray':
            adv_patch_cpu = torch.full((3, self.config.patch_size, self.config.patch_size), 0.5)
        elif type == 'random':
            adv_patch_cpu = torch.rand((3, self.config.patch_size, self.config.patch_size))
        return adv_patch_cpu

    def read_image(self, path):
        """
        Read an input image to be used as a patch.

        :param path: Path to the image to be read.
        :return: Returns the transformed patch as a pytorch Tensor.
        """
        patch_img = Image.open(path).convert('RGB')
        tf = transforms.Resize((self.config.patch_size, self.config.patch_size))
        patch_img = tf(patch_img)
        tf = transforms.ToTensor()
        adv_patch_cpu = tf(patch_img)
        return adv_patch_cpu


def main():
    if len(sys.argv) != 2:
        print('You need to supply (only) a configuration mode.')
        print('Possible modes are:')
        print(patch_config.patch_configs)  # 'paper_obj' is usually the mode passed in
        sys.exit(1)
    # print('sys.argv:', sys.argv)
    trainer = PatchTrainer(sys.argv[1])
    trainer.train()


if __name__ == '__main__':
    main()
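

# Example invocation (a sketch; the valid modes are the keys of
# patch_config.patch_configs, with 'paper_obj' mentioned above as the usual one):
#
#   python train_patch.py paper_obj
#
# TensorBoard logs are written to runs/<timestamp>_<mode>, and intermediate
# patches are saved to saved_patches/ from epoch 3 onwards.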