xiaolin_code/dataset.py

102 lines
3.8 KiB
Python
Raw Normal View History

2024-07-04 17:09:13 +08:00
import torch.utils.data as data
import os
import torch
import json
from PIL import Image
from torchvision import transforms
import math
import glob
from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from utils import ROOT_PATH
# JSON file loaded by the dataset classes in this module; its entries are
# looked up by '<image stem>.JPEG' and read for 'class_id' and 'class_name'.
name_to_class_ids_file = os.path.join(ROOT_PATH, 'image_name_to_class_id_and_name.json')
# Not referenced in the visible part of this file; presumably
# (channels, height, width) -- TODO confirm against callers.
INPUT_SIZE = (3, 224, 224)
# Value of the 'interpolation' entry returned by params().
INTERPOLATION = 'bicubic'
# Not referenced in the visible part of this file: params() hard-codes
# 'crop_pct' to 0.9 instead of using this value.
DEFAULT_CROP_PCT = 0.875
# Normalization mean/std that params() returns for the model names it lists
# explicitly.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Normalization mean/std that params() returns for every other model name.
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
def params(model_name):
    """Return the preprocessing settings dict for ``model_name``.

    The result always has the keys ``'mean'``, ``'std'``, ``'interpolation'``
    and ``'crop_pct'``. Only ``'mean'`` and ``'std'`` depend on the model:
    the names listed below get the IMAGENET_DEFAULT_* statistics, every
    other name gets the IMAGENET_INCEPTION_* statistics.
    """
    default_stat_models = (
        'vit_deit_base_distilled_patch16_224',
        'levit_256',
        'pit_b_224',
        'cait_s24_224',
        'convit_base',
        'visformer_small',
        'deit_base_distilled_patch16_224',
    )
    if model_name in default_stat_models:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
    # A fresh dict is built on every call so callers may mutate the result.
    return {
        'mean': mean,
        'std': std,
        'interpolation': INTERPOLATION,
        'crop_pct': 0.9,
    }
def transforms_imagenet_wo_resize(params):
    """Compose ``ToTensor`` followed by ``Normalize``, with no resizing step.

    Args:
        params: mapping providing ``'mean'`` and ``'std'``; both are
            converted with ``torch.tensor`` before being handed to
            ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` of the two steps.
    """
    to_tensor = transforms.ToTensor()
    mean_tensor = torch.tensor(params['mean'])
    std_tensor = torch.tensor(params['std'])
    normalize = transforms.Normalize(mean=mean_tensor, std=std_tensor)
    return transforms.Compose([to_tensor, normalize])
class AdvDataset(data.Dataset):
    """Dataset over the ``*.png`` files of a directory, labelled from a JSON file.

    Each item is ``(img, class_id, class_name, image_name)``. Labels are read
    from the JSON file at ``name_to_class_ids_file`` under the key
    ``'<stem>.JPEG'``, where ``<stem>`` is the part of the PNG file name
    before its first ``'.'``.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files in ``adv_path`` and load the label JSON.

        Args:
            model_name: passed to ``params`` to pick the normalization
                statistics; the name ``"tf2torch_resnet_v2_101"`` makes
                ``__getitem__`` resize to 299x299 and skip normalization.
            adv_path: directory containing the PNG images.
        """
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        # Built once here instead of on every __getitem__ call.
        self._resize_to_tensor = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
        ])
        names = self._png_names(adv_path)
        print('Using ', len(names))
        # The JSON key uses the whitespace-stripped stem; the on-disk path
        # keeps the real file name so the file can actually be opened.
        self.query_paths = [n.strip().split('.')[0] + '.JPEG' for n in names]
        self.paths = [os.path.join(adv_path, n) for n in names]
        self.model_name = model_name
        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    @staticmethod
    def _png_names(adv_path):
        """Return the sorted base names of the ``*.png`` files in ``adv_path``."""
        # glob.escape keeps '[', '*' or '?' in the directory name from being
        # treated as wildcards; sorting makes the dataset order reproducible
        # because glob returns matches in arbitrary order.
        pattern = os.path.join(glob.escape(adv_path), '*.png')
        # os.path.basename handles the platform's path separator, unlike
        # splitting on '/'.
        return sorted(os.path.basename(p) for p in glob.glob(pattern))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load one image and its label.

        Returns:
            ``(img, class_id, class_name, image_name)`` where ``img`` is the
            transformed image and ``image_name`` is the PNG file name.

        Raises:
            KeyError: if the image has no entry in the label JSON.
        """
        path = self.paths[index]
        label = self.json_info[self.query_paths[index]]
        class_id = label['class_id']
        class_name = label['class_name']
        image_name = os.path.basename(path)
        # convert() returns an independent image, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.model_name == "tf2torch_resnet_v2_101":
            img = self._resize_to_tensor(img)
        elif self.transform is not None:
            img = self.transform(img)
        return img, class_id, class_name, image_name
class CNNDataset(data.Dataset):
    """Dataset over the ``*.png`` files of a directory, resized to 299x299.

    Each item is ``(img, class_id, class_name, image_name)``. Labels are read
    from the JSON file at ``name_to_class_ids_file`` under the key
    ``'<stem>.JPEG'``, where ``<stem>`` is the part of the PNG file name
    before its first ``'.'``. Images are resized and converted to a tensor
    only; ``self.transform`` is stored but not applied by ``__getitem__``.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files in ``adv_path`` and load the label JSON.

        Args:
            model_name: passed to ``params`` to build ``self.transform``.
            adv_path: directory containing the PNG images.
        """
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        # Built once here instead of on every __getitem__ call.
        self._resize_to_tensor = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
        ])
        names = self._png_names(adv_path)
        print('Using ', len(names))
        # The JSON key uses the whitespace-stripped stem; the on-disk path
        # keeps the real file name so the file can actually be opened.
        self.query_paths = [n.strip().split('.')[0] + '.JPEG' for n in names]
        self.paths = [os.path.join(adv_path, n) for n in names]
        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    @staticmethod
    def _png_names(adv_path):
        """Return the sorted base names of the ``*.png`` files in ``adv_path``."""
        # glob.escape keeps '[', '*' or '?' in the directory name from being
        # treated as wildcards; sorting makes the dataset order reproducible
        # because glob returns matches in arbitrary order.
        pattern = os.path.join(glob.escape(adv_path), '*.png')
        # os.path.basename handles the platform's path separator, unlike
        # splitting on '/'.
        return sorted(os.path.basename(p) for p in glob.glob(pattern))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load one image, resize it to 299x299 and return it with its label.

        Returns:
            ``(img, class_id, class_name, image_name)`` where ``image_name``
            is the PNG file name.

        Raises:
            KeyError: if the image has no entry in the label JSON.
        """
        path = self.paths[index]
        label = self.json_info[self.query_paths[index]]
        class_id = label['class_id']
        class_name = label['class_name']
        image_name = os.path.basename(path)
        # convert() returns an independent image, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self._resize_to_tensor(img)
        return img, class_id, class_name, image_name