import collections.abc

import torch


class Compose(object):
    """Composes several collate transforms together.

    Args:
        transforms (list of ``Collate`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, batch):
        for transform in self.transforms:
            batch = transform(batch)
        return batch
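
# Example (illustrative sketch, not part of the original module): the transforms below
# are typically chained into a single collate_fn for a DataLoader. The `dataset`
# variable and the batch size are hypothetical.
#
#   collate_fn = Compose([
#       ListDictsToDictLists(),
#       PadTensors(value=0),
#       StackTensors(),
#   ])
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
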
class ListDictsToDictLists(object):

    def __init__(self):
        pass

    def __call__(self, batch):
        batch = self.ld_to_dl(batch)
        return batch

    def ld_to_dl(self, batch):
        """
        Convert a list of dictionaries into a dictionary of lists.

        Args:
            batch (list): A list of dictionaries where each dictionary represents a data sample.

        Returns:
            dict: A dictionary where each key corresponds to a key in the original dictionaries,
                and each value is a list containing the values associated with that key across
                all dictionaries in the batch.
        """
        if isinstance(batch[0], collections.abc.Mapping):
            # If the first element of the batch is a dictionary, assume that all elements
            # are dictionaries with the same keys.
            # collections.abc.Mapping is used instead of dict because it is the abstract
            # base class for mappings in Python: it also matches collections.OrderedDict,
            # collections.defaultdict, and custom mapping types that behave like
            # dictionaries but are not instances of the built-in dict type.
            return {key: self.ld_to_dl([d[key] for d in batch]) for key in batch[0]}
        else:
            return batch
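
# Example (illustrative sketch): converting a list of per-sample dictionaries into a
# dictionary of lists. The keys 'a' and 'b' are hypothetical.
#
#   samples = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
#   ListDictsToDictLists()(samples)
#   # -> {'a': [1, 3], 'b': [2, 4]}
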
class PadTensors(object):
    """Pads tensors within a batch to a uniform size, ensuring that all tensors in the
    batch have the same shape. This class is essential for dealing with variable-length
    sequences in batched inputs, such as text or time series data, allowing them to be
    processed by fixed-size neural network layers. The padding value, along with
    specific keys to include or exclude from padding, can be customized.
    """

    def __init__(self, value=0, use_keys=None, avoid_keys=None):
        self.value = value
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.pad_tensors(batch)
        return batch

    def pad_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.pad_tensors(value)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            # Pad every tensor in the list to the maximum size found along each dimension.
            max_size = [max(item.size(i) for item in batch) for i in range(batch[0].dim())]
            max_size = torch.Size(max_size)
            n_batch = []
            for item in batch:
                if item.size() != max_size:
                    n_item = item.new(max_size).fill_(self.value)
                    # TODO: Improve this
                    if item.dim() == 1:
                        n_item[:item.size(0)] = item
                    elif item.dim() == 2:
                        n_item[:item.size(0), :item.size(1)] = item
                    elif item.dim() == 3:
                        n_item[:item.size(0), :item.size(1), :item.size(2)] = item
                    else:
                        raise ValueError(
                            'PadTensors only supports tensors with 1, 2 or 3 dimensions, '
                            'got {}'.format(item.dim()))
                    n_batch.append(n_item)
                else:
                    n_batch.append(item)
            return n_batch
        else:
            return batch
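
# Example (illustrative sketch): padding variable-length 1-D tensors stored under the
# hypothetical key 'tokens'; non-tensor values such as 'id' are returned unchanged.
#
#   batch = {'tokens': [torch.ones(2), torch.ones(4)], 'id': [0, 1]}
#   out = PadTensors(value=0)(batch)
#   # out['tokens'][0] -> tensor([1., 1., 0., 0.]); out['id'] -> [0, 1]
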
class StackTensors(object):
    """Stacks tensors along a new dimension (usually the batch dimension), creating a
    single tensor from a list of tensors. This operation is key for creating batched
    inputs from individual samples. The class optionally supports stacking in shared
    memory, reducing memory overhead when working with large data or in multi-process
    settings.
    """

    def __init__(self, use_shared_memory=False, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.stack_tensors(batch)
        return batch

    # The key argument is useful for debugging.
    def stack_tensors(self, batch, key=None):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if key not in self.avoid_keys:
                    out[key] = self.stack_tensors(value, key=key)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.stack(batch, 0, out=out)
        else:
            return batch
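
# Example (illustrative sketch): stacking equally-sized tensors along a new batch
# dimension; the key 'tokens' is hypothetical.
#
#   batch = {'tokens': [torch.ones(4), torch.zeros(4)]}
#   out = StackTensors()(batch)
#   # out['tokens'].shape -> torch.Size([2, 4])
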
class CatTensors(object):
    """Concatenates tensors along an existing dimension, as opposed to stacking them
    along a new one. This class supports selective concatenation based on specified
    keys and can perform the operation in shared memory to optimize multi-process data
    loading and transformation pipelines. Additionally, it can automatically handle the
    creation of batch indices ('batch_id') for concatenated tensors.
    """

    def __init__(self, use_shared_memory=False, use_keys=None, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.cat_tensors(batch)
        return batch

    def cat_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.cat_tensors(value)
                    if ('batch_id' not in out) and torch.is_tensor(value[0]):
                        out['batch_id'] = torch.cat(
                            [i * torch.ones(x.size(0)) for i, x in enumerate(value)], 0)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.cat(batch, 0, out=out)
        else:
            return batch
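
# Example (illustrative sketch): concatenating variable-length tensors along dim 0 and
# recording the originating sample index in 'batch_id'. The key 'points' is hypothetical.
#
#   batch = {'points': [torch.rand(3, 2), torch.rand(5, 2)]}
#   out = CatTensors(use_keys=['points'])(batch)
#   # out['points'].shape -> torch.Size([8, 2])
#   # out['batch_id'] -> tensor([0., 0., 0., 1., 1., 1., 1., 1.])
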
class SortByKey(object):
    """Sorts all lists in a batch according to the values stored under ``key``
    (by default 'lengths'), keeping the lists aligned sample-wise.
    """

    def __init__(self, key='lengths', reverse=True):
        self.key = key
        self.reverse = reverse
        self.sort_keys = []
        self.i = 0

    def __call__(self, batch):
        self.set_sort_keys(batch[self.key])  # must be a list
        batch = self.sort_by_key(batch)
        return batch

    def set_sort_keys(self, sort_keys):
        self.i = 0
        self.sort_keys = sort_keys

    # Ugly hack to be able to sort without a lambda function: sorted() calls the key
    # function once per element, in input order, so cycling through self.sort_keys
    # assigns the right sort value to each position.
    def get_key(self, _):
        key = self.sort_keys[self.i]
        self.i += 1
        if self.i >= len(self.sort_keys):
            self.i = 0
        return key

    def sort_by_key(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            return {key: self.sort_by_key(value) for key, value in batch.items()}
        elif type(batch) is list:  # rather than isinstance(batch, collections.abc.Sequence)
            return sorted(batch, key=self.get_key, reverse=self.reverse)
        else:
            return batch
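
# Example (illustrative sketch): sorting a batch by descending length; the keys
# 'lengths' and 'ids' are hypothetical. All lists are permuted consistently.
#
#   batch = {'lengths': [2, 5, 3], 'ids': ['a', 'b', 'c']}
#   out = SortByKey(key='lengths', reverse=True)(batch)
#   # out -> {'lengths': [5, 3, 2], 'ids': ['b', 'c', 'a']}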