"""COCO dataset loader"""
import torch
import torch.utils.data as data
import os
import os.path as osp
import numpy as np
from imageio import imread
import random
import json
import cv2
import logging
logger = logging.getLogger(__name__)
class RawImageDataset(data.Dataset):
"""
Load precomputed captions and image features
Possible options: f30k_precomp, coco_precomp
"""
    def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
self.opt = opt
self.train = train
self.data_path = data_path
self.data_name = data_name
        self.tokenizer = tokenizer
loc_cap = osp.join(data_path, 'precomp')
loc_image = osp.join(data_path, 'precomp')
loc_mapping = osp.join(data_path, 'id_mapping.json')
if 'coco' in data_name:
self.image_base = osp.join(data_path, 'images')
else:
self.image_base = osp.join(data_path, 'flickr30k-images')
with open(loc_mapping, 'r') as f_mapping:
self.id_to_path = json.load(f_mapping)
# Read Captions
self.captions = []
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
for line in f:
self.captions.append(line.strip())
# Get the image ids
with open(osp.join(loc_image, '{}_ids.txt'.format(data_split)), 'r') as f:
image_ids = f.readlines()
self.images = [int(x.strip()) for x in image_ids]
        # Set related parameters according to the pre-trained backbone
assert 'backbone' in opt.precomp_enc_type
self.backbone_source = opt.backbone_source
self.base_target_size = 256
self.crop_ratio = 0.875
self.train_scale_rate = 1
if hasattr(opt, 'input_scale_factor') and opt.input_scale_factor != 1:
self.base_target_size = int(self.base_target_size * opt.input_scale_factor)
logger.info('Input images are scaled by factor {}'.format(opt.input_scale_factor))
if 'detector' in self.backbone_source:
self.pixel_means = np.array([[[102.9801, 115.9465, 122.7717]]])
else:
self.imagenet_mean = [0.485, 0.456, 0.406]
self.imagenet_std = [0.229, 0.224, 0.225]
self.length = len(self.captions)
        # Each image has 5 captions; if images are not repeated per caption, divide by 5 to map caption index to image index
num_images = len(self.images)
if num_images != self.length:
self.im_div = 5
else:
self.im_div = 1
# the development set for coco is large and so validation would be slow
if data_split == 'dev':
self.length = 5000
def __getitem__(self, index):
img_index = index // self.im_div
caption = self.captions[index]
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
# Convert caption (string) to word ids (with Size Augmentation at training time).
target = process_caption(self.tokenizer, caption_tokens, self.train)
image_id = self.images[img_index]
image_path = os.path.join(self.image_base, self.id_to_path[str(image_id)])
im_in = np.array(imread(image_path))
processed_image = self._process_image(im_in)
image = torch.Tensor(processed_image)
image = image.permute(2, 0, 1)
return image, target, index, img_index
def __len__(self):
return self.length
def _process_image(self, im_in):
"""
Converts an image into a network input, with pre-processing including re-scaling, padding, etc, and data
augmentation.
"""
        # 1. Convert a single-channel (grayscale) image to 3 channels
        if len(im_in.shape) == 2:
            im_in = im_in[:, :, np.newaxis]
            im_in = np.concatenate((im_in, im_in, im_in), axis=2)
        if 'detector' in self.backbone_source:
            # Detector-style backbones expect BGR channel order
            im_in = im_in[:, :, ::-1]
im = im_in.astype(np.float32, copy=True)
if self.train:
target_size = self.base_target_size * self.train_scale_rate
else:
target_size = self.base_target_size
        # 2. Random crop in training mode, otherwise skip
if self.train:
crop_ratio = np.random.random() * 0.4 + 0.6
crop_size_h = int(im.shape[0] * crop_ratio)
crop_size_w = int(im.shape[1] * crop_ratio)
processed_im = self._crop(im, crop_size_h, crop_size_w, random=True)
else:
processed_im = im
# 3. Resize to the target resolution
im_shape = processed_im.shape
im_scale_x = float(target_size) / im_shape[1]
im_scale_y = float(target_size) / im_shape[0]
processed_im = cv2.resize(processed_im, None, None, fx=im_scale_x, fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
        if self.train:
            # Random horizontal flip with probability 0.5
            if np.random.random() > 0.5:
                processed_im = self._hori_flip(processed_im)
# Normalization
if 'detector' in self.backbone_source:
processed_im = self._detector_norm(processed_im)
else:
processed_im = self._imagenet_norm(processed_im)
return processed_im
def _imagenet_norm(self, im_in):
im_in = im_in.astype(np.float32)
im_in = im_in / 255
for i in range(im_in.shape[-1]):
im_in[:, :, i] = (im_in[:, :, i] - self.imagenet_mean[i]) / self.imagenet_std[i]
return im_in
def _detector_norm(self, im_in):
im_in = im_in.astype(np.float32)
im_in -= self.pixel_means
return im_in
@staticmethod
def _crop(im, crop_size_h, crop_size_w, random):
h, w = im.shape[0], im.shape[1]
if random:
if w - crop_size_w == 0:
x_start = 0
else:
x_start = np.random.randint(w - crop_size_w, size=1)[0]
if h - crop_size_h == 0:
y_start = 0
else:
y_start = np.random.randint(h - crop_size_h, size=1)[0]
else:
x_start = (w - crop_size_w) // 2
y_start = (h - crop_size_h) // 2
cropped_im = im[y_start:y_start + crop_size_h, x_start:x_start + crop_size_w, :]
return cropped_im
@staticmethod
def _hori_flip(im):
im = np.fliplr(im).copy()
return im
class PrecompRegionDataset(data.Dataset):
"""
Load precomputed captions and image features for COCO or Flickr
"""
def __init__(self, data_path, data_name, data_split, tokenizer, opt, train):
self.tokenizer = tokenizer
self.opt = opt
self.train = train
self.data_path = data_path
self.data_name = data_name
# loc_cap = osp.join(data_path, 'precomp')
# loc_image = osp.join(data_path, 'precomp')
loc_cap = data_path
loc_image = data_path
# Captions
self.captions = []
with open(osp.join(loc_cap, '%s_caps.txt' % data_split), 'r') as f:
for line in f:
self.captions.append(line.strip())
# Image features
        self.images = np.load(os.path.join(loc_image, '%s_ims.npy' % data_split), mmap_mode='r')
self.length = len(self.captions)
        # Each image has 5 captions; if images are not repeated per caption, divide by 5 to map caption index to image index
num_images = len(self.images)
if num_images != self.length:
self.im_div = 5
else:
self.im_div = 1
# the development set for coco is large and so validation would be slow
if data_split == 'dev':
self.length = 5000
# if data_split == 'test':
# self.length = 5000
def __getitem__(self, index):
# handle the image redundancy
img_index = index // self.im_div
caption = self.captions[index]
caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
# Convert caption (string) to word ids (with Size Augmentation at training time)
target = process_caption(self.tokenizer, caption_tokens, self.train)
image = self.images[img_index]
if self.train: # Size augmentation for region feature
num_features = image.shape[0]
rand_list = np.random.rand(num_features)
image = image[np.where(rand_list > 0.20)]
image = torch.Tensor(image)
return image, target, index, img_index
def __len__(self):
return self.length
def process_caption(tokenizer, tokens, train=True):
output_tokens = []
deleted_idx = []
for i, token in enumerate(tokens):
sub_tokens = tokenizer.wordpiece_tokenizer.tokenize(token)
prob = random.random()
if prob < 0.20 and train: # mask/remove the tokens only during training
prob /= 0.20
# 50% randomly change token to mask token
if prob < 0.5:
for sub_token in sub_tokens:
output_tokens.append("[MASK]")
# 10% randomly change token to random token
elif prob < 0.6:
for sub_token in sub_tokens:
output_tokens.append(random.choice(list(tokenizer.vocab.keys())))
            # remaining 40%: record the token for deletion below (size augmentation)
else:
for sub_token in sub_tokens:
output_tokens.append(sub_token)
deleted_idx.append(len(output_tokens) - 1)
else:
for sub_token in sub_tokens:
                # keep the token unchanged (outside the 20% masking budget, or at test time)
output_tokens.append(sub_token)
if len(deleted_idx) != 0:
output_tokens = [output_tokens[i] for i in range(len(output_tokens)) if i not in deleted_idx]
output_tokens = ['[CLS]'] + output_tokens + ['[SEP]']
target = tokenizer.convert_tokens_to_ids(output_tokens)
target = torch.Tensor(target)
return target
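# Hedged usage sketch for process_caption (not part of the original code): it assumes a
# "slow" HuggingFace BERT tokenizer such as transformers.BertTokenizer, which exposes the
# basic_tokenizer, wordpiece_tokenizer, vocab and convert_tokens_to_ids members used above.
#
#     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#     tokens = tokenizer.basic_tokenizer.tokenize('a man riding a horse')
#     target = process_caption(tokenizer, tokens, train=True)
#
# target is a 1-D tensor of word-piece ids wrapped in [CLS] ... [SEP]; during training
# roughly 20% of the input tokens are masked, replaced by a random token, or dropped.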
def collate_fn(data):
"""Build mini-batch tensors from a list of (image, caption) tuples.
Args:
data: list of (image, caption) tuple.
- image: torch tensor of shape (3, 256, 256).
- caption: torch tensor of shape (?); variable length.
Returns:
images: torch tensor of shape (batch_size, 3, 256, 256).
targets: torch tensor of shape (batch_size, padded_length).
lengths: list; valid length for each padded caption.
"""
images, captions, ids, img_ids = zip(*data)
if len(images[0].shape) == 2: # region feature
        # Pad region features (convert tuple of 2D tensors to a padded 3D tensor)
img_lengths = [len(image) for image in images]
all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1))
for i, image in enumerate(images):
end = img_lengths[i]
all_images[i, :end] = image[:end]
img_lengths = torch.Tensor(img_lengths)
        # Merge captions (convert tuple of 1D tensors to a 2D tensor)
lengths = [len(cap) for cap in captions]
targets = torch.zeros(len(captions), max(lengths)).long()
for i, cap in enumerate(captions):
end = lengths[i]
targets[i, :end] = cap[:end]
return all_images, img_lengths, targets, lengths, ids
else: # raw input image
        # Merge images (convert tuple of 3D tensors to a 4D tensor)
images = torch.stack(images, 0)
        # Merge captions (convert tuple of 1D tensors to a 2D tensor)
lengths = [len(cap) for cap in captions]
targets = torch.zeros(len(captions), max(lengths)).long()
for i, cap in enumerate(captions):
end = lengths[i]
targets[i, :end] = cap[:end]
return images, targets, lengths, ids
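# Note: collate_fn has two return signatures. For precomputed region features it returns
# (images, img_lengths, targets, lengths, ids); for raw input images it returns
# (images, targets, lengths, ids). Downstream code should unpack according to
# opt.precomp_enc_type.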
def get_loader(data_path, data_name, data_split, tokenizer, opt, batch_size=100,
shuffle=True, num_workers=0, train=True):
"""Returns torch.utils.data.DataLoader for custom coco dataset."""
if train:
drop_last = True
else:
drop_last = False
if opt.precomp_enc_type == 'basic':
dset = PrecompRegionDataset(data_path, data_name, data_split, tokenizer, opt, train)
data_loader = torch.utils.data.DataLoader(dataset=dset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
collate_fn=collate_fn,
num_workers=num_workers,
drop_last=drop_last)
else:
dset = RawImageDataset(data_path, data_name, data_split, tokenizer, opt, train)
        data_loader = torch.utils.data.DataLoader(dataset=dset,
                                                  batch_size=batch_size,
                                                  shuffle=shuffle,
                                                  num_workers=num_workers,
                                                  pin_memory=True,
                                                  collate_fn=collate_fn,
                                                  drop_last=drop_last)
return data_loader
def get_loaders(data_path, data_name, tokenizer, batch_size, workers, opt):
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
batch_size, True, workers)
val_loader = get_loader(data_path, data_name, 'dev', tokenizer, opt,
batch_size, False, workers, train=False)
return train_loader, val_loader
def get_train_loader(data_path, data_name, tokenizer, batch_size, workers, opt, shuffle):
train_loader = get_loader(data_path, data_name, 'train', tokenizer, opt,
batch_size, shuffle, workers)
return train_loader
def get_test_loader(split_name, data_name, tokenizer, batch_size, workers, opt):
test_loader = get_loader(opt.data_path, data_name, split_name, tokenizer, opt,
batch_size, False, workers, train=False)
return test_loader
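

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original training pipeline). It assumes the
    # opt namespace carries the fields read above (data_path, precomp_enc_type,
    # backbone_source, input_scale_factor) and that the precomputed files
    # ({split}_caps.txt, {split}_ims.npy / {split}_ids.txt, id_mapping.json) exist under
    # data_path. The paths and values below are placeholders.
    from argparse import Namespace
    from transformers import BertTokenizer

    opt = Namespace(data_path='/path/to/coco_precomp', precomp_enc_type='basic',
                    backbone_source='', input_scale_factor=1)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    train_loader, val_loader = get_loaders(opt.data_path, 'coco_precomp', tokenizer,
                                           batch_size=32, workers=2, opt=opt)
    # With precomp_enc_type == 'basic', collate_fn yields 5-element batches.
    for images, img_lengths, targets, lengths, ids in val_loader:
        print(images.shape, targets.shape, len(lengths))
        break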