152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
from queue import Queue
|
|
from threading import Thread
|
|
|
|
import h5py
|
|
import nltk
|
|
import torch
|
|
import torch.utils.data as data
|
|
import os
|
|
import numpy as np
|
|
import json
|
|
from torch.utils.data import DataLoader
|
|
from prefetch_generator import BackgroundGenerator
|
|
from transformers import BertTokenizer
|
|
|
|
class DataLoaderX(DataLoader):
    """``DataLoader`` whose batch iterator is wrapped in ``BackgroundGenerator``.

    Everything except iteration is inherited unchanged from ``DataLoader``;
    ``__iter__`` hands the parent iterator to ``BackgroundGenerator`` so that
    batches can be prefetched while the consumer works on the current one.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        prefetching_iterator = BackgroundGenerator(base_iterator)
        return prefetching_iterator
|
|
|
|
|
|
class PrecompDataset(data.Dataset):
    """Load precomputed captions and image features.

    Possible options: f30k_precomp, coco_precomp

    Files read from ``data_path`` for a given ``data_split``:
      - ``<split>_caps.txt``: one caption per line, decoded as UTF-8.
      - ``<split>_ims.npy``: image feature array, opened memory-mapped
        read-only (``mmap_mode='r'``).

    When the number of feature rows differs from the number of captions,
    each feature row is assumed to be shared by 5 consecutive captions
    (``im_div = 5``); otherwise captions and rows are paired one-to-one.
    """

    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        # Callable mapping a token string to an integer id (see __getitem__).
        self.vocab = vocab
        # NOTE(review): not used by this class; kept because external code
        # may rely on ``dataset.tokenizer`` existing. Loading it can be slow.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Captions: read as bytes, strip each line, decode as UTF-8.
        caps_path = os.path.join(data_path, '%s_caps.txt' % data_split)
        with open(caps_path, 'rb') as f:
            self.captions = [line.strip().decode('utf-8') for line in f]

        self.data_split = data_split

        # Image features: memory-mapped so the whole array is not read
        # into RAM up front.
        print('loading npy')
        ims_path = os.path.join(data_path, '%s_ims.npy' % data_split)
        self.images = np.load(ims_path, mmap_mode='r')
        print('done load npy')

        self.length = len(self.captions)
        # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
        if self.images.shape[0] != self.length:
            self.im_div = 5
        else:
            self.im_div = 1
        # The development set for coco is large and so validation would be
        # slow: cap it at 5000 captions. min() keeps a smaller dev split
        # from reporting more items than it has (which would IndexError).
        if data_split == 'dev':
            self.length = min(self.length, 5000)

    def __getitem__(self, index):
        """Return ``(image, target, index, img_id)`` for caption ``index``.

        ``image`` is feature row ``index // im_div`` converted with
        ``torch.Tensor`` (default float dtype). ``target`` is a 1-D float
        tensor of vocab ids: ``<start>``, the lower-cased NLTK word tokens,
        then ``<end>``.
        """
        # Several captions can share one feature row (see im_div). Floor
        # division is exact for ints and also correct for negative indices.
        img_id = index // self.im_div
        image = torch.Tensor(self.images[img_id])
        vocab = self.vocab

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(self.captions[index])
        caption = [vocab('<start>')]
        caption.extend(vocab(str(token).lower()) for token in tokens)
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)

        return image, target, index, img_id

    def __len__(self):
        """Number of captions served (capped at 5000 for the dev split)."""
        return self.length
|
|
|
|
|
|
def collate_fn(data):
    """Build mini-batch tensors from a list of dataset samples.

    Args:
        data: list of ``(image, caption, index, img_id)`` tuples, as
            returned by ``PrecompDataset.__getitem__``.
            - image: tensor; all images in a batch must share one shape so
              they can be stacked.
            - caption: 1-D tensor of word ids; variable length.

    Returns:
        images: tensor of shape ``(batch_size, *image.shape)``.
        targets: LongTensor of shape ``(batch_size, max_length)``; each row
            holds a caption's ids followed by zero padding.
        lengths: LongTensor; valid length of each padded caption.
        ids: tuple of the samples' dataset indices.

    All outputs are ordered by caption length, longest first.
    """
    # Sort by caption length (longest first). sorted() is stable like
    # list.sort() but leaves the caller's list untouched.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)

    images, captions, ids, _img_ids = zip(*batch)

    # Merge images (tuple of equally-shaped tensors -> one stacked tensor).
    images = torch.stack(images, 0)

    # Merge captions (tuple of 1-D tensors -> zero-padded 2-D LongTensor).
    lengths = torch.LongTensor([len(cap) for cap in captions])
    targets = torch.zeros(len(captions), int(lengths.max())).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    return images, targets, lengths, ids
|
|
|
|
|
|
def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
                       shuffle=True, num_workers=0):
    """Return a prefetching ``DataLoaderX`` over one ``PrecompDataset`` split.

    Batches are assembled with ``collate_fn`` and pinned in memory
    (``pin_memory=True``). ``opt`` is accepted for interface compatibility
    but is not used here. Prints ``num_workers`` before building the loader.
    """
    dataset = PrecompDataset(data_path, data_split, vocab)
    print(num_workers)
    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    return DataLoaderX(**loader_kwargs)
|
|
|
|
|
|
def get_loaders(data_name, vocab, batch_size, workers, opt):
    """Return ``(train_loader, val_loader)`` for ``opt.data_path/data_name``.

    The 'train' split is shuffled; the 'dev' split is not. Both use the
    same batch size and worker count.
    """
    dataset_dir = os.path.join(opt.data_path, data_name)
    # (split name, shuffle flag) — train is built before dev, as before.
    split_specs = (('train', True), ('dev', False))
    return tuple(
        get_precomp_loader(dataset_dir, split, vocab, opt,
                           batch_size, shuffle_flag, workers)
        for split, shuffle_flag in split_specs
    )
|
|
|
|
|
|
def get_test_loader2(split_name, data_name, vocab, batch_size,
                     workers, opt):
    """Return an unshuffled loader for ``split_name`` of ``opt.data_path/data_name``."""
    return get_precomp_loader(
        os.path.join(opt.data_path, data_name),
        split_name,
        vocab,
        opt,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
    )
|