import os
import json
import glob

import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms

from utils import ROOT_PATH

# JSON file mapping each validation image name to its class id and class name.
name_to_class_ids_file = os.path.join(ROOT_PATH, 'image_name_to_class_id_and_name.json')

# Default input size and normalization statistics for the evaluated models.
INPUT_SIZE = (3, 224, 224)
INTERPOLATION = 'bicubic'
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)

def params(model_name):
    """Return the preprocessing parameters (normalization statistics,
    interpolation, crop percentage) expected by the given model."""
    if model_name in ['vit_deit_base_distilled_patch16_224', 'levit_256', 'pit_b_224',
                      'cait_s24_224', 'convit_base', 'visformer_small',
                      'deit_base_distilled_patch16_224']:
        # Models that expect the standard ImageNet normalization statistics.
        params = {'mean': IMAGENET_DEFAULT_MEAN,
                  'std': IMAGENET_DEFAULT_STD,
                  'interpolation': INTERPOLATION,
                  'crop_pct': 0.9}
    else:
        # Remaining models expect Inception-style normalization to [-1, 1].
        params = {'mean': IMAGENET_INCEPTION_MEAN,
                  'std': IMAGENET_INCEPTION_STD,
                  'interpolation': INTERPOLATION,
                  'crop_pct': 0.9}
    return params

def transforms_imagenet_wo_resize(params):
    """Convert a PIL image to a normalized tensor without any resizing or cropping."""
    tfl = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(params['mean']),
            std=torch.tensor(params['std']))
    ]
    return transforms.Compose(tfl)

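# Illustrative example (not part of the original pipeline): the transform above
# assumes the PIL image is already at the model's input resolution, e.g.
#   tf = transforms_imagenet_wo_resize(params('pit_b_224'))
#   tensor = tf(Image.open('sample.png').convert('RGB'))  # normalized CHW tensor
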
class AdvDataset(data.Dataset):
    """Adversarial PNG images preprocessed for a given target (mostly ViT) model."""

    def __init__(self, model_name, adv_path):
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        paths = glob.glob(os.path.join(adv_path, '*.png'))
        paths = [os.path.basename(i).strip() for i in paths]
        print('Using', len(paths), 'images')
        # Map each adversarial PNG back to the .JPEG name used as the key in
        # the class-id lookup file.
        self.query_paths = [i.split('.')[0] + '.JPEG' for i in paths]
        self.paths = [os.path.join(adv_path, i) for i in paths]
        self.model_name = model_name

        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        query_path = self.query_paths[index]
        class_id = self.json_info[query_path]['class_id']
        class_name = self.json_info[query_path]['class_name']
        image_name = os.path.basename(path)
        img = Image.open(path).convert('RGB')
        if self.model_name == "tf2torch_resnet_v2_101":
            # This converted TF model takes 299x299 inputs in [0, 1] with no
            # further normalization.
            img = transforms.Resize((299, 299))(img)
            img = transforms.ToTensor()(img)
        else:
            if self.transform is not None:
                img = self.transform(img)
        return img, class_id, class_name, image_name

class CNNDataset(data.Dataset):
    """Adversarial PNG images resized to 299x299 for CNN target models."""

    def __init__(self, model_name, adv_path):
        self.transform = transforms_imagenet_wo_resize(params(model_name))  # unused below
        paths = glob.glob(os.path.join(adv_path, '*.png'))
        paths = [os.path.basename(i).strip() for i in paths]
        print('Using', len(paths), 'images')
        self.query_paths = [i.split('.')[0] + '.JPEG' for i in paths]
        self.paths = [os.path.join(adv_path, i) for i in paths]

        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        query_path = self.query_paths[index]
        class_id = self.json_info[query_path]['class_id']
        class_name = self.json_info[query_path]['class_name']
        image_name = os.path.basename(path)
        img = Image.open(path).convert('RGB')
        # CNN targets receive 299x299 inputs in [0, 1]; no ImageNet
        # normalization is applied here.
        img = transforms.Resize((299, 299))(img)
        img = transforms.ToTensor()(img)
        return img, class_id, class_name, image_name
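
if __name__ == '__main__':
    # Minimal usage sketch, not from the original file. Assumes `adv_images/`
    # holds PNG adversarial examples named after their source ImageNet
    # validation images; the directory and model name are placeholders.
    dataset = AdvDataset('vit_deit_base_distilled_patch16_224', 'adv_images')
    loader = data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)

    for images, class_ids, class_names, image_names in loader:
        # `images` are normalized tensors ready for the chosen model;
        # `class_ids` are the ground-truth ImageNet class indices.
        print(images.shape)
        break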