import json
import os
import pickle
import random

import lmdb  # Lightning Memory-Mapped Database
import torch
import torch.utils.data as data
from PIL import Image

# Options is read throughout this module but was never imported; assuming the
# bootstrap.pytorch framework that the commented-out Logger() calls belong to.
from bootstrap.lib.options import Options

import transforms

from .batch_sampler import BatchSamplerTripletClassif


def default_items_tf():
    return transforms.Compose([
        transforms.ListDictsToDictLists(),
        transforms.PadTensors(value=0),
        transforms.StackTensors()
    ])

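# Collation sketch: roughly what the composed transform above does to a list of
# items, assuming the semantics suggested by the transform names (illustrative,
# not verified against the transforms module itself):
#
#   items = [{'match': torch.FloatTensor([1])},
#            {'match': torch.FloatTensor([-1])}]
#   batch = default_items_tf()(items)
#   # ListDictsToDictLists: [{'k': x}, {'k': y}] -> {'k': [x, y]}
#   # PadTensors(value=0):  zero-pads tensors in each list to a common shape
#   # StackTensors:         {'k': [x, y]} -> {'k': torch.stack([x, y])}
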
def default_image_tf(scale_size, crop_size, mean=None, std=None):
    mean = mean or [0.485, 0.456, 0.406]  # ImageNet statistics used for ResNet
    std = std or [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize(scale_size),
        transforms.RandomCrop(crop_size),
        # transforms.CenterCrop(size),
        transforms.ToTensor(),  # also divides by 255
        transforms.Normalize(mean=mean, std=std)
    ])
    return transform

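# Usage sketch (illustrative; `path` is a placeholder):
#
#   tf = default_image_tf(256, 224)
#   img = Image.open(path).convert('RGB')
#   x = tf(img)  # FloatTensor of shape (3, 224, 224), normalized per channel
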
class Dataset(data.Dataset):

    def __init__(self, dir_data, split, batch_size, nb_threads, items_tf=default_items_tf):
        super(Dataset, self).__init__()
        self.dir_data = dir_data
        self.split = split
        self.batch_size = batch_size
        self.nb_threads = nb_threads
        self.items_tf = items_tf

    def make_batch_loader(self, shuffle=True):
        # Always shuffle, even for valset/testset (see testing procedure),
        # except in debug mode, where determinism is preferred.
        if Options()['dataset'].get('debug', False):
            shuffle = False
        return data.DataLoader(
            self,
            batch_size=self.batch_size,
            num_workers=self.nb_threads,
            shuffle=shuffle,
            pin_memory=True,
            collate_fn=self.items_tf(),
            drop_last=True)  # drop the last batch if not full (quick fix for accuracy calculation with class 0 only)

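# Loader sketch (illustrative; shown on a concrete subclass, since Dataset
# only wires up the DataLoader):
#
#   dataset = Images('/path/to/recipe1m', 'train', batch_size=100, nb_threads=4)
#   loader = dataset.make_batch_loader()
#   batch = next(iter(loader))  # dict of stacked tensors, built by items_tf
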
class DatasetLMDB(Dataset):

    def __init__(self, dir_data, split, batch_size, nb_threads):
        super(DatasetLMDB, self).__init__(dir_data, split, batch_size, nb_threads)
        self.dir_lmdb = os.path.join(self.dir_data, 'data_lmdb')

        self.path_envs = {}
        self.path_envs['ids'] = os.path.join(self.dir_lmdb, split, 'ids.lmdb')
        self.path_envs['numims'] = os.path.join(self.dir_lmdb, split, 'numims.lmdb')
        self.path_envs['impos'] = os.path.join(self.dir_lmdb, split, 'impos.lmdb')
        self.path_envs['imnames'] = os.path.join(self.dir_lmdb, split, 'imnames.lmdb')
        self.path_envs['ims'] = os.path.join(self.dir_lmdb, split, 'ims.lmdb')
        self.path_envs['classes'] = os.path.join(self.dir_lmdb, split, 'classes.lmdb')

        # Only the 'ids' and 'classes' environments are opened here;
        # subclasses open the image-related ones they need.
        self.envs = {}
        self.envs['ids'] = lmdb.open(self.path_envs['ids'], readonly=True, lock=False)
        self.envs['classes'] = lmdb.open(self.path_envs['classes'], readonly=True, lock=False)

        self.txns = {}
        self.txns['ids'] = self.envs['ids'].begin(write=False, buffers=True)
        self.txns['classes'] = self.envs['classes'].begin(write=False, buffers=True)

        self.nb_recipes = self.envs['ids'].stat()['entries']

        # Class vocabulary built by im2recipe:
        # https://github.com/torralba-lab/im2recipe/blob/master/pyscripts/bigrams.py#L176
        self.path_pkl = os.path.join(self.dir_data, 'classes1M.pkl')
        with open(self.path_pkl, 'rb') as f:
            _ = pickle.load(f)  # skip the first pickled object
            self.classes = pickle.load(f)  # second object: class_id -> class_name

        self.cname_to_cid = {v: k for k, v in self.classes.items()}

    def encode(self, value):
        return pickle.dumps(value)

    def decode(self, bytes_value):
        return pickle.loads(bytes_value)

    def get(self, index, env_name):
        # LMDB keys are pickled Python objects (here: integer indices),
        # so lookups must pickle the index the same way the writer did.
        buf = self.txns[env_name].get(self.encode(index))
        value = self.decode(bytes(buf))
        return value

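    # Writer-side sketch for reference (hypothetical; the actual extraction
    # scripts live outside this module). Keys and values are both pickled,
    # which is what encode()/decode() above assume:
    #
    #   env = lmdb.open('ids.lmdb', map_size=2 ** 30)
    #   with env.begin(write=True) as txn:
    #       for index, recipe_id in enumerate(recipe_ids):
    #           txn.put(pickle.dumps(index), pickle.dumps(recipe_id))
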
    def _load_class(self, index):
        class_id = self.get(index, 'classes') - 1  # lua (1-based) to python (0-based)
        return torch.LongTensor([class_id]), self.classes[class_id]

    def __len__(self):
        return self.nb_recipes

    def true_nb_images(self):
        # note: the 'imnames' environment is only opened by subclasses
        # such as Images, so this is only valid on those
        return self.envs['imnames'].stat()['entries']


class Images(DatasetLMDB):

    def __init__(self, dir_data, split, batch_size, nb_threads, image_from='database',
                 image_tf=default_image_tf(256, 224), use_vcs=False, get_all_images=False,
                 kw_path=None, randkw_p=None, tokenizer=None, aux_kwords=False,
                 aux_kw_path=None, randkw_p_aux=None, random_kw=False, random_aux_kw=False):

        super(Images, self).__init__(dir_data, split, batch_size, nb_threads)
        self.image_tf = image_tf
        self.dir_img = os.path.join(dir_data, 'recipe1M', 'images')

        self.envs['numims'] = lmdb.open(self.path_envs['numims'], readonly=True, lock=False)
        self.envs['impos'] = lmdb.open(self.path_envs['impos'], readonly=True, lock=False)
        self.envs['imnames'] = lmdb.open(self.path_envs['imnames'], readonly=True, lock=False)

        self.txns['numims'] = self.envs['numims'].begin(write=False, buffers=True)
        self.txns['impos'] = self.envs['impos'].begin(write=False, buffers=True)
        self.txns['imnames'] = self.envs['imnames'].begin(write=False, buffers=True)

        # images can be read either from the LMDB database or from disk
        self.image_from = image_from
        if self.image_from == 'database':
            self.envs['ims'] = lmdb.open(self.path_envs['ims'], readonly=True, lock=False)
            self.txns['ims'] = self.envs['ims'].begin(write=False, buffers=True)
        self.use_vcs = use_vcs
        self.get_all_images = get_all_images
        self.aux_kwords = aux_kwords

        self.random_kw = random_kw
        self.random_aux_kw = random_aux_kw

        if self.use_vcs:
            # Logger()('Load VCs...')
            with open(kw_path, 'r') as f:
                self.image_path_to_kws = json.load(f)

            if self.aux_kwords:
                with open(aux_kw_path, 'r') as f:
                    self.image_path_to_aux_kws = json.load(f)

            # random keyword subsampling is a train-time augmentation only
            if split == 'train':
                self.randkw_p = randkw_p
                self.randkw_p_aux = randkw_p_aux
            else:
                self.randkw_p = None
                self.randkw_p_aux = None
            # Logger()('randkw_p...', self.randkw_p)

            # hard-coded root of the images the keyword files are indexed by
            self.dir_img_vcs = '/data/mshukor/data/recipe1m/recipe1M/images'
            self.tokenizer = tokenizer

    def __getitem__(self, index):
        item = self.get_image(index)
        return item

    def format_path_img(self, raw_path):
        # e.g. "recipe1M/images/train/6/b/d/c/6bdca6e490.jpg"
        # images are sharded by the first four characters of the file name
        basename = os.path.basename(raw_path)
        path_img = os.path.join(self.dir_img,
                                self.split,
                                basename[0],
                                basename[1],
                                basename[2],
                                basename[3],
                                basename)
        return path_img

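    # Worked example of the sharding above (split == 'train'):
    #
    #   format_path_img('recipe1M/images/train/6/b/d/c/6bdca6e490.jpg')
    #   -> <dir_img>/train/6/b/d/c/6bdca6e490.jpg
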
    def get_image(self, index):
        item = {}
        if self.get_all_images:
            item['samples'] = self._load_image_data(index)
        else:
            item['data'], item['index'], item['path'] = self._load_image_data(index)
            item['class_id'], item['class_name'] = self._load_class(index)
            if self.use_vcs:
                # print(self.image_path_to_kws.keys())
                kw_path = item['path'].replace(self.dir_img, self.dir_img_vcs)
                if self.random_kw:
                    # replace the keywords with those of a random image
                    rand_index = random.choice(range(len(self)))
                    _, _, rand_path = self._load_image_data(rand_index)
                    kw_path = rand_path.replace(self.dir_img, self.dir_img_vcs)
                if kw_path in self.image_path_to_kws:
                    kwords = self.image_path_to_kws[kw_path]
                    if self.randkw_p is not None:
                        # keep a random subset of the keywords
                        num_kw = int(self.randkw_p * len(kwords))
                        kws = random.choices(kwords, k=num_kw)
                    else:
                        kws = kwords

                    # note: auxiliary keywords are only produced when the main
                    # keywords were found for this image
                    if self.aux_kwords:
                        if self.random_aux_kw:
                            rand_index = random.choice(range(len(self)))
                            _, _, rand_path = self._load_image_data(rand_index)
                            kw_path = rand_path.replace(self.dir_img, self.dir_img_vcs)
                        if kw_path in self.image_path_to_aux_kws:
                            aux_words = self.image_path_to_aux_kws[kw_path]
                            if self.randkw_p_aux is not None:
                                num_kw = int(self.randkw_p_aux * len(aux_words))
                                aux_kws = random.choices(aux_words, k=num_kw)
                            else:
                                aux_kws = aux_words
                        else:
                            aux_kws = ['food', 'food']  # fallback keywords

                        aux_kws = [' '.join(aux_kws)]
                        # tokenize the auxiliary keywords (BERT-style tokenizer)
                        aux_kws = self.tokenizer(aux_kws, padding='longest', truncation=True,
                                                 max_length=55, return_tensors="pt")
                        item['aux_kwords_ids'] = aux_kws.input_ids[0]
                        item['aux_kwords_masks'] = aux_kws.attention_mask[0]
                else:
                    kws = ['food', 'food']  # fallback keywords
                    # Logger()("kws not found", item['path'])

                kws = [' '.join(kws)]
                # tokenize the keywords (BERT-style tokenizer)
                kws = self.tokenizer(kws, padding='longest', truncation=True,
                                     max_length=55, return_tensors="pt")
                item['kwords_ids'] = kws.input_ids[0]
                item['kwords_masks'] = kws.attention_mask[0]

        return item

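    # Item structure sketch (illustrative): with use_vcs=True and
    # get_all_images=False, a returned item looks roughly like
    #
    #   {'data': FloatTensor(3, 224, 224), 'index': int, 'path': str,
    #    'class_id': LongTensor([c]), 'class_name': str,
    #    'kwords_ids': LongTensor(L), 'kwords_masks': LongTensor(L),
    #    'aux_kwords_ids': ..., 'aux_kwords_masks': ...}  # if aux_kwords=True
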
    def _pil_loader(self, path):
        # open path as file to avoid ResourceWarning
        # (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def _load_image_data(self, index):
        nb_images = self.get(index, 'numims')
        if self.get_all_images:
            # return every image attached to this recipe
            images = []
            for im_idx in range(nb_images):
                index_img = self.get(index, 'impos')[im_idx] - 1  # lua to python indexing

                path_img = self.format_path_img(self.get(index_img, 'imnames'))

                if self.image_from == 'pil_loader':
                    image_data = self._pil_loader(path_img)
                elif self.image_from == 'database':
                    image_data = self.get(index_img, 'ims')

                if self.image_tf is not None:
                    image_data = self.image_tf(image_data)
                item = {}
                item['data'], item['index'], item['path'] = image_data, index_img, path_img
                images.append(item)
            return images
        else:
            # select a random image from the list of images for that sample
            # (or the first one, deterministically, in debug mode)
            try:
                if Options()['dataset'].get("debug", False):
                    im_idx = 0
                else:
                    im_idx = torch.randperm(nb_images)[0]
            except Exception:  # Options may be unavailable outside a full run
                im_idx = torch.randperm(nb_images)[0]

            index_img = self.get(index, 'impos')[im_idx] - 1  # lua to python indexing

            path_img = self.format_path_img(self.get(index_img, 'imnames'))

            if self.image_from == 'pil_loader':
                image_data = self._pil_loader(path_img)
            elif self.image_from == 'database':
                image_data = self.get(index_img, 'ims')

            if self.image_tf is not None:
                image_data = self.image_tf(image_data)

            return image_data, index_img, path_img

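    # Index plumbing sketch, as used above: for a recipe index i,
    #
    #   numims[i]  -> number of images attached to recipe i
    #   impos[i]   -> 1-based image indices for recipe i
    #   imnames[j] -> raw file name of image j, where j = impos[i][k] - 1
    #   ims[j]     -> stored image data (when image_from == 'database')

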
class Recipes_raw(DatasetLMDB):

    def __init__(self, dir_data, split, batch_size, nb_threads):
        super(Recipes_raw, self).__init__(dir_data, split, batch_size, nb_threads)

        # ~added for visu
        self.path_layer1 = os.path.join(dir_data, 'text', 'tokenized_layer1.json')
        with open(self.path_layer1, 'r') as f:
            self.layer1 = json.load(f)
        # ~end

        self.with_titles = Options()['model']['network'].get('with_titles', False)
        self.tokenized_raw_text = Options()['dataset'].get('tokenized_raw_text', False)
        self.max_instrs_len = Options()['dataset'].get('max_instrs_len', 20)
        self.max_ingrs_len = Options()['dataset'].get('max_ingrs_len', 15)
        self.max_instrs = Options()['dataset'].get('max_instrs', 20)
        self.max_ingrs = Options()['dataset'].get('max_ingrs', 20)

        self.remove_list = Options()['dataset'].get('remove_list', None)
        self.interchange_ingrd_instr = Options()['dataset'].get('interchange_ingrd_instr', None)
        # Logger()('recipe elements to remove:', self.remove_list)

    def __getitem__(self, index):
        item = self.get_recipe(index)
        return item

    def get_recipe(self, index):
        item = {}
        item['class_id'], item['class_name'] = self._load_class(index)

        item['index'] = index
        # ~added for visu
        item['ids'] = self.get(index, 'ids')
        item['layer1'] = self.layer1[item['ids']]

        item['layer1']['title'] = torch.LongTensor(item['layer1']['title'])
        if self.remove_list is not None and 'title' in self.remove_list:
            item['layer1']['title'] = torch.LongTensor([167838, 178987, 59198])  # [start, unk, end] tokens

        if self.remove_list is not None and 'ingredients' in self.remove_list:
            item['layer1']['ingredients'] = torch.LongTensor([[167838, 178987, 59198]])
        else:
            # keep at most max_ingrs ingredients of max_ingrs_len tokens each,
            # then right-pad every ingredient with 0 to a common length
            tokenized_ingrs = item['layer1']['ingredients'][:self.max_ingrs]
            tokenized_ingrs = [l[:self.max_ingrs_len] for l in tokenized_ingrs]
            max_len = max([len(l) for l in tokenized_ingrs])
            tokenized_ingrs = [l + (max_len - len(l)) * [0] for l in tokenized_ingrs]
            item['layer1']['ingredients'] = torch.LongTensor(tokenized_ingrs)

        if self.remove_list is not None and 'instructions' in self.remove_list:
            item['layer1']['instructions'] = torch.LongTensor([[167838, 178987, 59198]])
        else:
            tokenized_instrs = item['layer1']['instructions'][:self.max_instrs]
            tokenized_instrs = [l[:self.max_instrs_len] for l in tokenized_instrs]
            max_len = max([len(l) for l in tokenized_instrs])
            tokenized_instrs = [l + (max_len - len(l)) * [0] for l in tokenized_instrs]
            item['layer1']['instructions'] = torch.LongTensor(tokenized_instrs)

        if self.interchange_ingrd_instr is not None:
            # ablation: swap the ingredients and instructions fields
            tmp = item['layer1']['instructions'].clone()
            item['layer1']['instructions'] = item['layer1']['ingredients'].clone()
            item['layer1']['ingredients'] = tmp
        # ~end
        return item

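    # Padding example with illustrative token ids (max_ingrs_len >= 4):
    #
    #   [[3, 7], [1, 2, 9, 4]]
    #   -> truncated, then right-padded with 0 to the longest row
    #   -> LongTensor([[3, 7, 0, 0],
    #                  [1, 2, 9, 4]])

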
class Recipe1M(DatasetLMDB):

    def __init__(self, dir_data, split, batch_size=100, nb_threads=4, freq_mismatch=0.,
                 batch_sampler='triplet_classif',
                 image_from='database', image_tf=default_image_tf(256, 224),
                 use_vcs=False, kw_path=None, randkw_p=None, tokenizer=None,
                 aux_kwords=False, aux_kw_path=None, randkw_p_aux=None,
                 random_kw=False, random_aux_kw=False):
        super(Recipe1M, self).__init__(dir_data, split, batch_size, nb_threads)

        self.images_dataset = Images(dir_data, split, batch_size, nb_threads, image_from=image_from,
                                     image_tf=image_tf, use_vcs=use_vcs, kw_path=kw_path,
                                     randkw_p=randkw_p, tokenizer=tokenizer,
                                     aux_kwords=aux_kwords, aux_kw_path=aux_kw_path,
                                     randkw_p_aux=randkw_p_aux,
                                     random_kw=random_kw, random_aux_kw=random_aux_kw)
        self.tokenized_raw_text = Options()['dataset'].get('tokenized_raw_text', False)
        self.dataset_revamping = Options()['dataset'].get('dataset_revamping', False)
        if self.tokenized_raw_text:
            self.recipes_dataset = Recipes_raw(dir_data, split, batch_size, nb_threads)
        else:
            raise NotImplementedError("Only raw text is supported")
        self.freq_mismatch = freq_mismatch
        self.batch_sampler = batch_sampler

        if self.split == 'train' and self.batch_sampler == 'triplet_classif':
            self.indices_by_class = self._make_indices_by_class()

    def _make_indices_by_class(self):
        # Logger()('Calculate indices by class...')
        indices_by_class = [[] for class_id in range(len(self.classes))]
        for index in range(len(self.recipes_dataset)):
            class_id = self._load_class(index)[0][0]  # because (class_id, class_name) and class_id is a Tensor
            indices_by_class[class_id].append(index)
        # Logger()('Done!')
        return indices_by_class

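    # Resulting structure (illustrative):
    #
    #   indices_by_class[c] == [i for i in range(len(self)) if recipe i has class c]
    #
    # e.g. [[0, 17, 42, ...],  # class 0
    #       [3, 8, ...],       # class 1
    #       ...]
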
    def make_batch_loader(self, shuffle=True):
        if self.split in ['val', 'test'] or self.batch_sampler == 'random':
            if Options()['dataset'].get("debug", False):
                batch_loader = super(Recipe1M, self).make_batch_loader(shuffle=False)
            else:
                batch_loader = super(Recipe1M, self).make_batch_loader(shuffle=shuffle)
            # Logger()('Dataset will be sampled with "random" batch_sampler.')
        elif self.batch_sampler == 'triplet_classif':
            batch_sampler = BatchSamplerTripletClassif(
                self.indices_by_class,
                self.batch_size,
                pc_noclassif=0.5,
                nb_indices_same_class=2)
            batch_loader = data.DataLoader(self,
                                           num_workers=self.nb_threads,
                                           batch_sampler=batch_sampler,
                                           pin_memory=True,
                                           collate_fn=self.items_tf())
            # Logger()('Dataset will be sampled with "triplet_classif" batch_sampler.')
        else:
            raise ValueError('unknown batch_sampler: {}'.format(self.batch_sampler))
        return batch_loader

    def __getitem__(self, index):
        item = {}
        item['index'] = index
        item['recipe'] = self.recipes_dataset[index]

        # with probability freq_mismatch, return a mismatched image-recipe pair
        if self.freq_mismatch > 0:
            is_match = torch.rand(1)[0] > self.freq_mismatch
        else:
            is_match = True

        if is_match:
            item['image'] = self.images_dataset[index]  # the image of this recipe
            item['match'] = torch.FloatTensor([1])  # label for matching
        else:
            n_index = int(torch.rand(1)[0] * len(self))
            item['image'] = self.images_dataset[n_index]  # a random image
            item['match'] = torch.FloatTensor([-1])  # label for mismatching
        return item

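    # Pair sketch: with freq_mismatch = 0.2, roughly 80% of items pair a recipe
    # with its own image (match = +1) and 20% with a random image (match = -1),
    # which suits a pairwise ranking objective.

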
if __name__ == '__main__':
    pass
    # Logger(Options()['logs']['dir'])('lol')
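    # Smoke-test sketch (hypothetical paths; requires the extracted LMDB data
    # and an initialized Options()):
    #
    #   dataset = Recipe1M('/path/to/recipe1m', 'train', batch_size=4,
    #                      nb_threads=0, batch_sampler='random')
    #   loader = dataset.make_batch_loader()
    #   batch = next(iter(loader))
    #   print(batch['match'])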