# Graduation_Project/QN/RecipeRetrieval/dataset/transforms.py
import collections.abc

import torch

class Compose(object):
    """Composes several collate transforms together.

    Args:
        transforms (list of ``Collate`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, batch):
        for transform in self.transforms:
            batch = transform(batch)
        return batch
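
# A minimal usage sketch (hypothetical pipeline; the transforms it references
# are defined below): Compose chains collate transforms so the result can be
# passed as `collate_fn` to a torch.utils.data.DataLoader.
#
#     collate_fn = Compose([
#         ListDictsToDictLists(),
#         SortByKey(key='lengths'),
#         PadTensors(value=0),
#         StackTensors(),
#     ])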

class ListDictsToDictLists(object):

    def __init__(self):
        pass

    def __call__(self, batch):
        batch = self.ld_to_dl(batch)
        return batch

    def ld_to_dl(self, batch):
        """
        Convert a list of dictionaries into a dictionary of lists.

        Args:
            batch (list): A list of dictionaries, one per data sample.

        Returns:
            dict: A dictionary where each key corresponds to a key in the
                original dictionaries, and each value is the list of values
                associated with that key across all samples in the batch.
        """
        if isinstance(batch[0], collections.abc.Mapping):
            # If the first element is a mapping, assume every element is a
            # mapping with the same keys and recurse on each key.
            # collections.abc.Mapping is used instead of dict so that any
            # mapping type (dict, OrderedDict, defaultdict, custom mapping
            # classes) is handled, not just the built-in dict.
            return {key: self.ld_to_dl([d[key] for d in batch]) for key in batch[0]}
        else:
            return batch
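
# A minimal sketch of the conversion (hypothetical values):
#
#     >>> ListDictsToDictLists()([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
#     {'a': [1, 3], 'b': [2, 4]}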

class PadTensors(object):
    """Pads the tensors in a batch to a uniform size, so that all tensors
    under the same key share one shape. This is needed for variable-length
    data in batched inputs, such as text or time series, before it can be
    stacked and fed to fixed-size neural network layers. The padding value,
    along with specific keys to include or exclude, can be customized.
    """

    def __init__(self, value=0, use_keys=None, avoid_keys=None):
        self.value = value
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.pad_tensors(batch)
        return batch

    def pad_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.pad_tensors(value)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            # The target shape is the elementwise maximum over all tensors.
            max_size = [max(item.size(i) for item in batch) for i in range(batch[0].dim())]
            max_size = torch.Size(max_size)
            n_batch = []
            for item in batch:
                if item.size() != max_size:
                    n_item = item.new_full(max_size, self.value)
                    # TODO: Improve this (generalize to arbitrary dimensions)
                    if item.dim() == 1:
                        n_item[:item.size(0)] = item
                    elif item.dim() == 2:
                        n_item[:item.size(0), :item.size(1)] = item
                    elif item.dim() == 3:
                        n_item[:item.size(0), :item.size(1), :item.size(2)] = item
                    else:
                        raise ValueError(
                            'PadTensors supports at most 3 dimensions, '
                            'got {}'.format(item.dim()))
                    n_batch.append(n_item)
                else:
                    n_batch.append(item)
            return n_batch
        else:
            return batch
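
# A minimal sketch (hypothetical shapes): two 1-D tensors of lengths 2 and 3
# come out with a shared length of 3, the shorter one zero-padded at the end.
#
#     >>> pad = PadTensors(value=0)
#     >>> pad({'seq': [torch.tensor([1, 2]), torch.tensor([1, 2, 3])]})
#     {'seq': [tensor([1, 2, 0]), tensor([1, 2, 3])]}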

class StackTensors(object):
    """Stacks tensors along a new dimension (usually the batch dimension),
    creating a single tensor from a list of tensors. This is the step that
    turns individual samples into a batched input. Optionally, the stacking
    is done in shared memory, reducing memory overhead when working with
    large data or in multi-process settings.
    """

    def __init__(self, use_shared_memory=False, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.stack_tensors(batch)
        return batch

    # The key argument is useful for debugging.
    def stack_tensors(self, batch, key=None):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if key not in self.avoid_keys:
                    out[key] = self.stack_tensors(value, key=key)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, stack directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.stack(batch, 0, out=out)
        else:
            return batch
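
# A minimal sketch (hypothetical shapes): stacking three 1-D tensors of
# length 4 yields a single tensor of shape (3, 4).
#
#     >>> StackTensors()({'feat': [torch.zeros(4)] * 3})['feat'].shape
#     torch.Size([3, 4])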

class CatTensors(object):
    """Concatenates tensors along an existing dimension, as opposed to
    stacking them along a new one. Concatenation can be restricted to (or
    kept away from) specific keys, and can be done in shared memory to
    optimize multi-process data loading pipelines. Additionally, a
    'batch_id' tensor mapping each concatenated row back to its sample
    index is created automatically.
    """

    def __init__(self, use_shared_memory=False, use_keys=None, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.cat_tensors(batch)
        return batch

    def cat_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.cat_tensors(value)
                    if ('batch_id' not in out) and torch.is_tensor(value[0]):
                        # Record, for every concatenated row, the index of
                        # the sample it came from.
                        out['batch_id'] = torch.cat(
                            [i * torch.ones(x.size(0)) for i, x in enumerate(value)], 0)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.cat(batch, 0, out=out)
        else:
            return batch
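
# A minimal sketch (hypothetical shapes): concatenating tensors of shapes
# (2, 5) and (3, 5) yields one (5, 5) tensor, plus a 'batch_id' of length 5
# mapping each row back to its source sample.
#
#     >>> out = CatTensors()({'feat': [torch.zeros(2, 5), torch.ones(3, 5)]})
#     >>> out['feat'].shape, out['batch_id']
#     (torch.Size([5, 5]), tensor([0., 0., 1., 1., 1.]))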

class SortByKey(object):
    """Sorts every list in the batch according to the values stored under
    `key` (e.g. sequence lengths), so that entries stay aligned across keys
    after sorting.
    """

    def __init__(self, key='lengths', reverse=True):
        self.key = key
        self.reverse = reverse
        self.i = 0

    def __call__(self, batch):
        self.set_sort_keys(batch[self.key])  # must be a list
        batch = self.sort_by_key(batch)
        return batch

    def set_sort_keys(self, sort_keys):
        self.i = 0
        self.sort_keys = sort_keys

    # Ugly hack to be able to sort without a lambda function: sorted() calls
    # this once per element in list order, so each call hands back the next
    # sort key, then the counter wraps around for the next list.
    def get_key(self, _):
        key = self.sort_keys[self.i]
        self.i += 1
        if self.i >= len(self.sort_keys):
            self.i = 0
        return key

    def sort_by_key(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            return {key: self.sort_by_key(value) for key, value in batch.items()}
        elif type(batch) is list:  # isinstance(batch, collections.abc.Sequence)
            return sorted(batch, key=self.get_key, reverse=self.reverse)
        else:
            return batch
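

if __name__ == '__main__':
    # A minimal end-to-end sketch with hypothetical samples: each sample is a
    # dict with a variable-length 'seq' tensor and its length. The pipeline
    # groups by key, sorts by length (descending), pads, and stacks. The
    # 'lengths' key is kept out of the tensor transforms since it is a plain
    # list of ints.
    samples = [
        {'seq': torch.tensor([1, 2]), 'lengths': 2},
        {'seq': torch.tensor([1, 2, 3]), 'lengths': 3},
    ]
    collate_fn = Compose([
        ListDictsToDictLists(),
        SortByKey(key='lengths', reverse=True),
        PadTensors(value=0, avoid_keys=['lengths']),
        StackTensors(avoid_keys=['lengths']),
    ])
    batch = collate_fn(samples)
    print(batch['seq'])      # tensor([[1, 2, 3], [1, 2, 0]])
    print(batch['lengths'])  # [3, 2]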