"""Datasets that serve adversarial example images stored as ``*.png`` files.

Each PNG found directly under ``adv_path`` is paired with the label entry of
the same-named ``.JPEG`` key in ``image_name_to_class_id_and_name.json``
(located under ``ROOT_PATH``).
"""
import glob
import json
import math
import os

import torch
import torch.utils.data as data
from PIL import Image
from timm.data.transforms import (
    RandomResizedCropAndInterpolation,
    ToNumpy,
    ToTensor,
)
from torchvision import transforms

from utils import ROOT_PATH

# JSON file mapping '<image stem>.JPEG' -> {'class_id': ..., 'class_name': ...}.
name_to_class_ids_file = os.path.join(
    ROOT_PATH, 'image_name_to_class_id_and_name.json')

INPUT_SIZE = (3, 224, 224)
INTERPOLATION = 'bicubic'
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)

# Model names that are normalised with IMAGENET_DEFAULT_MEAN / _STD; every
# other name falls back to the IMAGENET_INCEPTION_* statistics.
_DEFAULT_NORM_MODELS = (
    'vit_deit_base_distilled_patch16_224',
    'levit_256',
    'pit_b_224',
    'cait_s24_224',
    'convit_base',
    'visformer_small',
    'deit_base_distilled_patch16_224',
)

# Model name for which AdvDataset resizes to 299x299 and skips normalisation.
_RAW_299_MODEL = "tf2torch_resnet_v2_101"

# Resize to 299x299 and convert to a tensor, without normalisation.
# Built once here instead of on every __getitem__ call.
_RAW_299_TRANSFORM = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
])


def params(model_name):
    """Return the preprocessing settings used for ``model_name``.

    Args:
        model_name: Name of the model the images will be fed to.

    Returns:
        A new dict with the keys ``'mean'``, ``'std'``, ``'interpolation'``
        and ``'crop_pct'``. Only ``'mean'`` and ``'std'`` depend on the
        model name; ``'crop_pct'`` is 0.9 in both cases.
    """
    if model_name in _DEFAULT_NORM_MODELS:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
    # A differently named local avoids shadowing this function's own name.
    return {
        'mean': mean,
        'std': std,
        'interpolation': INTERPOLATION,
        'crop_pct': 0.9,
    }


def transforms_imagenet_wo_resize(params):
    """Build a tensor-conversion + normalisation transform (no resizing).

    Args:
        params: Mapping providing ``'mean'`` and ``'std'`` sequences, e.g.
            the dict returned by :func:`params`. Other keys are ignored.

    Returns:
        A ``transforms.Compose`` applying ``ToTensor`` followed by
        ``Normalize`` with the given statistics.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(params['mean']),
            std=torch.tensor(params['std'])),
    ])


def _list_png_names(adv_path):
    """Return the sorted file names of the ``*.png`` files in ``adv_path``."""
    # os.path.basename handles whatever separator glob produced, and sorting
    # makes the index -> image mapping independent of file-system order.
    return sorted(
        os.path.basename(found).strip()
        for found in glob.glob(os.path.join(adv_path, '*.png'))
    )


def _query_name(file_name):
    """Map an adversarial file name to its ``.JPEG`` key in the label JSON."""
    # splitext removes only the final extension, so stems that themselves
    # contain dots are kept intact.
    return os.path.splitext(file_name)[0] + '.JPEG'


def _load_class_info():
    """Load the image-name -> class-info mapping from the label JSON file."""
    with open(name_to_class_ids_file, 'r') as ipt:
        return json.load(ipt)


def _load_rgb(path):
    """Open the image at ``path`` and return it converted to RGB."""
    # The context manager releases the file handle deterministically;
    # convert() hands back an image that stays usable afterwards.
    with Image.open(path) as handle:
        return handle.convert('RGB')


def _fetch_record(dataset, index):
    """Return ``(path, class_id, class_name, image_name)`` for one sample.

    Raises:
        IndexError: If ``index`` is outside the dataset.
        KeyError: If the sample's ``.JPEG`` key is missing from the label JSON.
    """
    path = dataset.paths[index]
    info = dataset.json_info[dataset.query_paths[index]]
    return path, info['class_id'], info['class_name'], os.path.basename(path)


class AdvDataset(data.Dataset):
    """Adversarial PNG images prepared for a specific model.

    Items are ``(img, class_id, class_name, image_name)`` tuples. Images are
    converted to tensors and normalised with the statistics chosen by
    :func:`params`, except for ``"tf2torch_resnet_v2_101"``, which instead
    gets a 299x299 resize and no normalisation.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files in ``adv_path`` and load the label mapping.

        Args:
            model_name: Model the images are prepared for.
            adv_path: Directory containing the adversarial ``*.png`` files.
        """
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        names = _list_png_names(adv_path)
        print('Using ', len(names))
        self.query_paths = [_query_name(name) for name in names]
        self.paths = [os.path.join(adv_path, name) for name in names]
        self.model_name = model_name
        self.json_info = _load_class_info()

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(img, class_id, class_name, image_name)`` for ``index``."""
        path, class_id, class_name, image_name = _fetch_record(self, index)
        img = _load_rgb(path)
        if self.model_name == _RAW_299_MODEL:
            # NOTE(review): this model gets un-normalised 299x299 input;
            # presumably it normalises internally -- confirm.
            img = _RAW_299_TRANSFORM(img)
        elif self.transform is not None:
            img = self.transform(img)
        return img, class_id, class_name, image_name


class CNNDataset(data.Dataset):
    """Adversarial PNG images resized to 299x299 without normalisation.

    Items are ``(img, class_id, class_name, image_name)`` tuples.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files in ``adv_path`` and load the label mapping.

        Args:
            model_name: Model name; only used to build ``self.transform``.
            adv_path: Directory containing the adversarial ``*.png`` files.
        """
        # Not used by __getitem__; kept so the attribute remains available.
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        names = _list_png_names(adv_path)
        print('Using ', len(names))
        self.query_paths = [_query_name(name) for name in names]
        self.paths = [os.path.join(adv_path, name) for name in names]
        self.json_info = _load_class_info()

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(img, class_id, class_name, image_name)`` for ``index``."""
        path, class_id, class_name, image_name = _fetch_record(self, index)
        img = _RAW_299_TRANSFORM(_load_rgb(path))
        return img, class_id, class_name, image_name