import collections.abc

import torch


class Compose(object):
    """Composes several collate transforms together.

    Args:
        transforms (list of ``Collate`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, batch):
        for transform in self.transforms:
            batch = transform(batch)
        return batch


class ListDictsToDictLists(object):

    def __init__(self):
        pass

    def __call__(self, batch):
        batch = self.ld_to_dl(batch)
        return batch

    def ld_to_dl(self, batch):
        """Convert a list of dictionaries into a dictionary of lists.

        Args:
            batch (list): A list of dictionaries where each dictionary
                represents a data sample.

        Returns:
            dict: A dictionary where each key corresponds to a key in the
                original dictionaries, and each value is a list containing the
                values associated with that key across all dictionaries in the
                batch.
        """
        # If the first element of the batch is a mapping, assume all elements
        # are mappings with the same keys. collections.abc.Mapping is used
        # instead of dict so that dict-like types (collections.OrderedDict,
        # collections.defaultdict, custom mapping types) are recognized as
        # well, not just instances of the built-in dict.
        if isinstance(batch[0], collections.abc.Mapping):
            return {key: self.ld_to_dl([d[key] for d in batch]) for key in batch[0]}
        else:
            return batch


class PadTensors(object):
    """Pads the tensors within a batch to a uniform size, ensuring that all
    tensors in the batch have the same shape. This is essential for dealing
    with variable-length sequences in batched inputs, such as text or time
    series data, allowing them to be processed by fixed-size neural network
    layers. The padding value, along with the specific keys to include or
    exclude, can be customized.
    """

    def __init__(self, value=0, use_keys=None, avoid_keys=None):
        self.value = value
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.pad_tensors(batch)
        return batch

    def pad_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.pad_tensors(value)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            # The target shape is the per-dimension maximum over the batch.
            max_size = [max(item.size(i) for item in batch)
                        for i in range(batch[0].dim())]
            max_size = torch.Size(max_size)
            n_batch = []
            for item in batch:
                if item.size() != max_size:
                    n_item = item.new(max_size).fill_(self.value)
                    # Copy the original item into the leading corner of the
                    # padded tensor, whatever its dimensionality.
                    n_item[tuple(slice(0, s) for s in item.size())] = item
                    n_batch.append(n_item)
                else:
                    n_batch.append(item)
            return n_batch
        else:
            return batch
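
# A minimal usage sketch of PadTensors (the `_demo_pad_tensors` name and the
# sample data are illustrative, not part of the original API): two 1-D
# tensors of different lengths are padded to the length of the longest one.
def _demo_pad_tensors():
    batch = {'tokens': [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]}
    padded = PadTensors(value=0)(batch)
    # padded['tokens'] == [tensor([1, 2, 3]), tensor([4, 5, 0])]
    return padded
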
class StackTensors(object):
    """Stacks tensors along a new dimension (usually the batch dimension),
    creating a single tensor from a list of tensors. This operation is key
    for creating batched inputs from individual samples. The class optionally
    supports stacking into shared memory, reducing memory overhead when
    working with large data or in multi-process settings.
    """

    def __init__(self, use_shared_memory=False, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.stack_tensors(batch)
        return batch

    # The key argument is useful for debugging.
    def stack_tensors(self, batch, key=None):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if key not in self.avoid_keys:
                    out[key] = self.stack_tensors(value, key=key)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, stack directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.stack(batch, 0, out=out)
        else:
            return batch


class CatTensors(object):
    """Concatenates tensors along an existing dimension, as opposed to
    stacking them along a new one. This class supports selective
    concatenation based on specified keys and can perform the operation in
    shared memory to optimize multi-process data loading and transformation
    pipelines. Additionally, it automatically creates batch indices
    (``batch_id``) for the concatenated tensors.
    """

    def __init__(self, use_shared_memory=False, use_keys=None, avoid_keys=None):
        self.use_shared_memory = use_shared_memory
        self.use_keys = use_keys or []
        self.avoid_keys = avoid_keys or []

    def __call__(self, batch):
        batch = self.cat_tensors(batch)
        return batch

    def cat_tensors(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if (key in self.use_keys) or \
                        (len(self.use_keys) == 0 and key not in self.avoid_keys):
                    out[key] = self.cat_tensors(value)
                    if ('batch_id' not in out) and torch.is_tensor(value[0]):
                        out['batch_id'] = torch.cat(
                            [i * torch.ones(x.size(0)) for i, x in enumerate(value)], 0)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, concatenate directly into
                # a shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            return torch.cat(batch, 0, out=out)
        else:
            return batch


class SortByKey(object):
    """Sorts every list in the batch according to the values stored under
    ``key`` (e.g. sequence lengths), in descending order by default."""

    def __init__(self, key='lengths', reverse=True):
        self.key = key
        self.reverse = reverse
        self.sort_keys = []
        self.i = 0

    def __call__(self, batch):
        self.set_sort_keys(batch[self.key])  # must be a list
        batch = self.sort_by_key(batch)
        return batch

    def set_sort_keys(self, sort_keys):
        self.i = 0
        self.sort_keys = sort_keys

    # Ugly hack to be able to sort without a lambda function: ``sorted`` calls
    # the key function once per element in iteration order, so returning the
    # precomputed sort keys one by one assigns the right key to each element.
    def get_key(self, _):
        key = self.sort_keys[self.i]
        self.i += 1
        if self.i >= len(self.sort_keys):
            self.i = 0
        return key

    def sort_by_key(self, batch):
        if isinstance(batch, collections.abc.Mapping):
            return {key: self.sort_by_key(value)
                    for key, value in batch.items()}
        elif type(batch) is list:  # deliberately not isinstance(batch, collections.abc.Sequence)
            return sorted(batch, key=self.get_key, reverse=self.reverse)
        else:
            return batch
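
# A minimal end-to-end sketch of composing these transforms into a collate
# function (the sample data below is illustrative; in practice, pass
# `collate` as the `collate_fn` argument of torch.utils.data.DataLoader):
if __name__ == '__main__':
    collate = Compose([
        ListDictsToDictLists(),              # list of sample dicts -> dict of lists
        SortByKey(key='lengths'),            # sort samples by descending length
        PadTensors(avoid_keys=['lengths']),  # pad sequences to a common size
        StackTensors(),                      # stack padded tensors along dim 0
    ])
    samples = [
        {'tokens': torch.tensor([4, 5]), 'lengths': 2},
        {'tokens': torch.tensor([1, 2, 3]), 'lengths': 3},
    ]
    batch = collate(samples)
    print(batch['tokens'])   # tensor([[1, 2, 3], [4, 5, 0]])
    print(batch['lengths'])  # [3, 2]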