234 lines
8.5 KiB
Python
234 lines
8.5 KiB
Python
|
from queue import Queue
|
||
|
from threading import Thread
|
||
|
|
||
|
import h5py
|
||
|
import nltk
|
||
|
import torch
|
||
|
import torch.utils.data as data
|
||
|
import os
|
||
|
import numpy as np
|
||
|
import json
|
||
|
from torch.utils.data import DataLoader
|
||
|
from prefetch_generator import BackgroundGenerator
|
||
|
from transformers import BertTokenizer
|
||
|
|
||
|
class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in ``BackgroundGenerator``.

    Batching, shuffling and collation are inherited unchanged from
    ``DataLoader``; only iteration is wrapped. (prefetch_generator's
    ``BackgroundGenerator`` is intended to pull items from the wrapped
    iterator ahead of the consumer.)
    """

    def __iter__(self):
        batches = super().__iter__()
        return BackgroundGenerator(batches)
|
||
|
|
||
|
class PrecompShuffleDataset(data.Dataset):
    """
    Load precomputed captions and image features for the shuffle splits.

    Possible options: f30k_precomp, coco_precomp

    Captions are read from ``<data_path>/<data_split>_caps.txt`` (one per
    line) and image features are memory-mapped read-only from
    ``<data_path>/<data_split>_ims.npy``. Before BERT tokenization a
    "[SEP]" marker is inserted after every '.' in a caption.
    """

    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        # Stored for interface compatibility; captions are encoded with the
        # BERT tokenizer below, not with this vocabulary.
        self.vocab = vocab
        loc = data_path + '/'
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Captions: one per line, whitespace-stripped, decoded as UTF-8.
        self.captions = []
        with open(loc + '%s_caps.txt' % data_split, 'rb') as f:
            for line in f:
                self.captions.append(line.strip().decode('utf-8'))

        self.data_split = data_split

        # Image features, memory-mapped (mmap_mode='r') rather than read fully.
        print('loading npy')
        self.images = np.load(loc + '%s_ims.npy' % data_split, mmap_mode='r')
        print('done load npy')
        self.length = len(self.captions)

        # rkiros data has redundancy in images: when image and caption counts
        # differ, each image row is assumed to cover 5 consecutive captions;
        # 10crop data has one image row per caption.
        if self.images.shape[0] != self.length:
            self.im_div = 5
        else:
            self.im_div = 1

        # Split-specific overrides of the value detected above. The shuffle
        # dev set is large, so validation is truncated to 20000 captions.
        if data_split == 'shuffle_dev':
            self.length = 20000
            self.im_div = 4
        if data_split == 'shuffle_train':
            self.im_div = 20

    @staticmethod
    def _mark_sentences(caption):
        """Insert "[SEP]" after every '.' in ``caption``.

        The marker after the final sentence is removed together with its
        period (as the original ``[:-6]`` slice did); BertTokenizer.encode adds
        its own trailing [SEP] by default. A caption that does not end with
        '.' is no longer truncated by six characters.
        """
        marked = caption.replace('.', '.[SEP]')
        if marked.endswith('.[SEP]'):
            marked = marked[:-len('.[SEP]')]
        return marked

    def __getitem__(self, index):
        """Return ``(image, target, index, img_id)`` for caption ``index``.

        ``image`` is feature row ``index // im_div`` as a float tensor;
        ``target`` is a float tensor of the BERT token ids of the caption.
        """
        # Several consecutive captions share one image row.
        img_id = index // self.im_div
        image = torch.Tensor(self.images[img_id])
        caption = self._mark_sentences(self.captions[index])
        target = torch.Tensor(self.tokenizer.encode(caption))
        return image, target, index, img_id

    def __len__(self):
        return self.length
|
||
|
|
||
|
class PrecompDataset(data.Dataset):
    """
    Precomputed image features paired with BERT-tokenized captions.

    Possible options: f30k_precomp, coco_precomp

    Reads ``<data_path>/<data_split>_caps.txt`` (one caption per line) and
    memory-maps ``<data_path>/<data_split>_ims.npy`` read-only.
    """

    def __init__(self, data_path, data_split, vocab):
        print('word txt encoder')
        self.vocab = vocab  # kept on the instance; encoding uses the BERT tokenizer
        prefix = data_path + '/'
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # One caption per line: strip surrounding whitespace, decode as UTF-8.
        self.captions = []
        with open(prefix + '%s_caps.txt' % data_split, 'rb') as handle:
            self.captions.extend(raw.strip().decode('utf-8') for raw in handle)

        self.data_split = data_split

        print('loading npy')
        self.images = np.load(prefix + '%s_ims.npy' % data_split, mmap_mode='r')
        print('done load npy')

        self.length = len(self.captions)
        # When image rows and captions differ in number, each image row is
        # assumed to cover 5 consecutive captions (rkiros layout); otherwise
        # there is one row per caption (10crop layout).
        self.im_div = 1 if self.images.shape[0] == self.length else 5

        # Only the first 5000 captions of the large dev split are used, to
        # keep validation fast.
        if data_split == 'dev':
            self.length = 5000

    def __getitem__(self, index):
        """Return ``(image, target, index, img_id)`` for caption ``index``."""
        row = int(index / self.im_div)
        features = torch.Tensor(self.images[row])
        token_ids = self.tokenizer.encode(self.captions[index])
        return features, torch.Tensor(token_ids), index, row

    def __len__(self):
        return self.length
|
||
|
|
||
|
|
||
|
def collate_fn(data):
    """Build mini-batch tensors from a list of dataset samples.

    Args:
        data: list of ``(image, caption, index, img_id)`` tuples.
            - image: feature tensor; all images in the batch must share a shape
              so they can be stacked.
            - caption: 1-D tensor of token ids; variable length.

    Returns:
        images: the image tensors stacked along a new leading batch dimension.
        img_lengths: float tensor holding ``len(image)`` (size of each image's
            first dimension) for every sample.
        targets: long tensor of shape (batch_size, max_caption_length), with
            captions zero-padded on the right.
        lengths: long tensor with the unpadded length of each caption.
        ids: tuple of the ``index`` values of the samples.

    All outputs are ordered by caption length, longest first. The input list
    is not modified.
    """
    # Sort by caption length (longest first) without mutating the caller's list.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    images, captions, ids, _img_ids = zip(*batch)

    # Merge images (tuple of N-D tensors -> one (N+1)-D tensor).
    images = torch.stack(images, 0)

    # Merge captions (tuple of 1-D tensors -> zero-padded 2-D long tensor).
    lengths = torch.LongTensor([len(cap) for cap in captions])
    max_len = int(lengths.max())
    targets = torch.zeros(len(captions), max_len, dtype=torch.long)
    for i, cap in enumerate(captions):
        end = int(lengths[i])
        targets[i, :end] = cap[:end]

    img_lengths = torch.Tensor([len(image) for image in images])
    return images, img_lengths, targets, lengths, ids
|
||
|
|
||
|
|
||
|
def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
                       shuffle=True, num_workers=0):
    """Create a DataLoaderX over the PrecompDataset split ``data_split``.

    The loader pins memory and batches samples with ``collate_fn``. ``opt`` is
    accepted for interface compatibility but not read here. The worker count
    is printed before the loader is built.
    """
    dataset = PrecompDataset(data_path, data_split, vocab)
    print(num_workers)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    return DataLoaderX(dataset=dataset, **loader_kwargs)
|
||
|
|
||
|
|
||
|
def get_loaders(data_name, vocab, batch_size, workers, opt):
    """Return ``(train_loader, val_loader)`` for dataset ``data_name``.

    Both loaders read from ``opt.data_path/data_name``: the 'train' split is
    shuffled and the 'dev' split is not.
    """
    split_root = os.path.join(opt.data_path, data_name)
    return tuple(
        get_precomp_loader(split_root, split, vocab, opt,
                           batch_size, is_training, workers)
        for split, is_training in (('train', True), ('dev', False))
    )
|
||
|
|
||
|
|
||
|
def get_test_loader(split_name, data_name, vocab, batch_size,
                    workers, opt):
    """Return an unshuffled loader for split ``split_name`` of ``data_name``,
    read from ``opt.data_path``."""
    split_root = os.path.join(opt.data_path, data_name)
    return get_precomp_loader(split_root, split_name, vocab, opt,
                              batch_size, False, workers)
|