# Graduation_Project/WZM/data.py

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
import nltk
import yaml
import argparse
from vocab import deserialize_vocab
from PIL import Image


class PrecompDataset(data.Dataset):
    """Load precomputed captions and image features."""

    def __init__(self, data_split, vocab, opt):
        self.vocab = vocab
        self.loc = opt['dataset']['data_path']
        self.img_path = opt['dataset']['image_path']

        # Caption texts and the corresponding image filenames.
        self.images = []
        self.captions = []
        self.maxlength = 0

        # Train/val read the manually verified caption/filename lists;
        # the test split reads the unfiltered ones.
        if data_split != 'test':
            with open(os.path.join(self.loc, '%s_caps_verify.txt' % data_split), 'r', encoding='utf-8') as f:
                for line in f:
                    self.captions.append(line.strip())
            with open(os.path.join(self.loc, '%s_filename_verify.txt' % data_split), 'r', encoding='utf-8') as f:
                for line in f:
                    self.images.append(line.strip())
        else:
            with open(os.path.join(self.loc, '%s_caps.txt' % data_split), 'r', encoding='utf-8') as f:
                for line in f:
                    self.captions.append(line.strip())
            with open(os.path.join(self.loc, '%s_filename.txt' % data_split), 'r', encoding='utf-8') as f:
                for line in f:
                    self.images.append(line.strip())

        self.length = len(self.captions)
        # rkiros-style data stores five captions per image, so five caption
        # indices map onto one image; otherwise the mapping is one-to-one.
        if len(self.images) != self.length:
            self.im_div = 5
        else:
            self.im_div = 1

        if data_split == "train":
            self.transform = transforms.Compose([
                # transforms.Resize((278, 278)),
                transforms.Resize((256, 256)),
                transforms.RandomRotation(degrees=(0, 90)),
                # transforms.RandomCrop(256),
                transforms.RandomCrop(224),
                transforms.ToTensor(),
                # ImageNet mean/std normalization
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])
        else:
            self.transform = transforms.Compose([
                # transforms.Resize((256, 256)),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])

    def __getitem__(self, index):
        # Handle the image redundancy: im_div caption indices share one image.
        img_id = index // self.im_div
        caption = self.captions[index]
        vocab = self.vocab

        # Convert the caption (string) to word ids, dropping punctuation and
        # mapping out-of-vocabulary tokens to '<unk>'.
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%']
        tokens = [k for k in tokens if k not in punctuations]
        tokens_UNK = [k if k in vocab.word2idx else '<unk>' for k in tokens]

        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens_UNK])
        caption.append(vocab('<end>'))
        target = torch.LongTensor(caption)

        image = Image.open(os.path.join(self.img_path, self.images[img_id])).convert('RGB')
        image = self.transform(image)  # torch.Size([3, 224, 224])
        return image, target, index, img_id

    def __len__(self):
        return self.length
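

# Each dataset sample is an (image, target, index, img_id) tuple; with
# im_div == 5, caption indices 0..4 all map onto image 0, indices 5..9 onto
# image 1, and so on. collate_fn below merges a list of such tuples into
# one padded batch.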

def collate_fn(batch):
    """Build mini-batch tensors from a list of (image, caption, index, img_id) tuples."""
    # Sort the batch by caption length, descending, as expected by
    # pack_padded_sequence-style consumers.
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions, ids, img_ids = zip(*batch)

    # Merge images (convert a tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (convert a tuple of 1D tensors to a zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    # Guard against zero lengths, which would break downstream packing.
    lengths = [l if l != 0 else 1 for l in lengths]
    return images, targets, lengths, ids
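
# Illustrative example: for two captions of lengths 5 and 3, collate_fn
# returns `targets` as a (2, 5) LongTensor whose second row is zero-padded
# after position 3, and lengths == [5, 3].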

def get_precomp_loader(data_split, vocab, batch_size=100,
                       shuffle=True, num_workers=0, opt={}):
    """Returns a torch.utils.data.DataLoader for the given precomputed split."""
    dset = PrecompDataset(data_split, vocab, opt)
    data_loader = torch.utils.data.DataLoader(dataset=dset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              num_workers=num_workers)
    return data_loader

def get_loaders(vocab, opt):
    train_loader = get_precomp_loader('train', vocab,
                                      opt['dataset']['batch_size'], True,
                                      opt['dataset']['workers'], opt=opt)
    val_loader = get_precomp_loader('val', vocab,
                                    opt['dataset']['batch_size_val'], False,
                                    opt['dataset']['workers'], opt=opt)
    return train_loader, val_loader

def get_test_loader(vocab, opt):
    test_loader = get_precomp_loader('test', vocab,
                                     opt['dataset']['batch_size_val'], False,
                                     opt['dataset']['workers'], opt=opt)
    return test_loader
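

# Minimal smoke-test sketch (an illustration, not part of the training
# pipeline). It assumes a YAML config exposing the opt['dataset'][...] keys
# read above, plus a 'vocab_path' entry and a deserialize_vocab(path) loader
# -- both assumptions, not defined in this file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True,
                        help='path to a YAML config with a dataset section (assumed layout)')
    args = parser.parse_args()
    with open(args.config, 'r', encoding='utf-8') as f:
        opt = yaml.safe_load(f)
    vocab = deserialize_vocab(opt['dataset']['vocab_path'])  # assumed key
    train_loader, val_loader = get_loaders(vocab, opt)
    # Pull one batch to verify shapes: images (B, 3, 224, 224), targets (B, max_len).
    images, targets, lengths, ids = next(iter(train_loader))
    print(images.shape, targets.shape, lengths[:5])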