import numbers

import torch
import torch.nn as nn


def _channel_value(value, index):
    """Return the statistic to use for position ``index`` along dim 1.

    Scalars (Python/NumPy numbers and 0-dim tensors) are broadcast to every
    position; any other value is indexed as ``value[index]``.
    """
    if isinstance(value, numbers.Number):
        return value
    if torch.is_tensor(value) and value.dim() == 0:
        return value
    return value[index]


def _normalize_channels(x, mean, std):
    """Apply ``(x[:, i] - mean_i) / std_i`` for every ``i`` along dim 1, in place.

    Results are written back into ``x``, so they take ``x``'s dtype.
    Returns ``x`` for convenience.
    """
    for i in range(x.size(1)):
        x[:, i] = (x[:, i] - _channel_value(mean, i)) / _channel_value(std, i)
    return x


class Normalize(nn.Module):
    """Per-channel normalization: ``out[:, i] = (input[:, i] - mean[i]) / std[i]``.

    Dimension 1 of the input is iterated over (presumably the channel axis of
    an NCHW-style tensor; confirm with callers).

    Args:
        mean: per-index means (list, tuple, tensor, ...), or a scalar that is
            applied to every index along dim 1.
        std: per-index divisors, same accepted forms as ``mean``.

    The input is cloned first, so the caller's tensor is never modified.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, input):
        # Clone so the in-place slice writes never touch the caller's tensor.
        return _normalize_channels(input.clone(), self.mean, self.std)


class TfNormalize(nn.Module):
    """Input normalization selectable between two conventions.

    Args:
        mean: used only in ``'torch'`` mode; per-index means along dim 1 or a
            scalar broadcast to every index. Default ``0``.
        std: used only in ``'torch'`` mode; same forms as ``mean``. Default ``1``.
        mode: ``'tensorflow'`` computes ``input * 2 - 1`` (maps [0, 1] to
            [-1, 1]); ``'torch'`` computes ``(input[:, i] - mean_i) / std_i``
            like :class:`Normalize`. Any other value returns an unmodified
            copy of the input.
    """

    def __init__(self, mean=0, std=1, mode='tensorflow'):
        super().__init__()
        self.mean = mean
        self.std = std
        self.mode = mode

    def forward(self, input):
        if self.mode == 'tensorflow':
            # Out-of-place arithmetic already yields a new tensor; no clone needed.
            return input * 2.0 - 1.0
        x = input.clone()
        if self.mode == 'torch':
            # Scalar defaults (0, 1) are broadcast, making this an identity op
            # instead of the former TypeError from indexing an int.
            _normalize_channels(x, self.mean, self.std)
        # Unknown modes fall through and return the unmodified copy.
        return x


class Permute(nn.Module):
    """Reorder entries along dimension 1 by advanced indexing.

    Args:
        permutation: indices giving the new order along dim 1. ``None`` (the
            default) means ``[2, 1, 0]``, which reverses a 3-entry axis
            (presumably RGB <-> BGR channel order).
    """

    def __init__(self, permutation=None):
        super().__init__()
        # Build a fresh list per instance; the former ``[2, 1, 0]`` default
        # was a single list object shared by every instance.
        self.permutation = [2, 1, 0] if permutation is None else permutation

    def forward(self, input):
        return input[:, self.permutation]