"""Compare two equivalent ways of unfolding an image batch into patches.

Both ``official`` and ``my_self`` rearrange a ``[B, C, H, W]`` tensor into
``[B * P, N, C]``, where ``P = patch_h * patch_w`` is the number of pixels in
a patch and ``N = n_h * n_w`` is the number of patches per image.  Running the
module as a script checks that the two produce identical results and prints a
simple timing comparison.
"""
import time

import torch

# Default problem size used by the demo in ``__main__``.
batch_size = 8
in_channels = 32
patch_h = 2
patch_w = 2
num_patch_h = 16
num_patch_w = 16
num_patches = num_patch_h * num_patch_w
patch_area = patch_h * patch_w


def _grid_dims(x: torch.Tensor):
    """Return ``(B, C, n_h, n_w)`` for a 4-D ``[B, C, H, W]`` tensor.

    ``n_h`` and ``n_w`` are the patch counts along height and width for the
    module-level ``patch_h`` / ``patch_w``.  If ``H`` or ``W`` is not a
    multiple of the patch size, the subsequent ``reshape`` in the caller
    raises because the element counts no longer match.
    """
    b, c, h, w = x.shape
    return b, c, h // patch_h, w // patch_w


def official(x: torch.Tensor) -> torch.Tensor:
    """Unfold ``x`` into patches using a 4-D intermediate reshape.

    Args:
        x: Tensor of shape ``[B, C, H, W]`` with ``H`` divisible by
            ``patch_h`` and ``W`` divisible by ``patch_w``.

    Returns:
        Tensor of shape ``[B * P, N, C]`` where
        ``out[b * P + ph * patch_w + pw, ih * n_w + iw, c]`` equals
        ``x[b, c, ih * patch_h + ph, iw * patch_w + pw]``.

    Raises:
        ValueError: If ``x`` is not 4-D (shape unpacking fails).
        RuntimeError: If ``H``/``W`` are not multiples of the patch size.
    """
    b, c, n_h, n_w = _grid_dims(x)
    n = n_h * n_w
    # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
    x = x.reshape(b * c * n_h, patch_h, n_w, patch_w)
    # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
    x = x.transpose(1, 2)
    # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P]
    x = x.reshape(b, c, n, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]; C is spelled out (not -1) so that
    # zero-element inputs such as an empty batch reshape unambiguously.
    return x.reshape(b * patch_area, n, c)


def my_self(x: torch.Tensor) -> torch.Tensor:
    """Unfold ``x`` into patches using a 6-D intermediate reshape.

    Produces exactly the same result as ``official``; only the intermediate
    view differs (the batch and channel axes are kept separate throughout).

    Args:
        x: Tensor of shape ``[B, C, H, W]`` with ``H`` divisible by
            ``patch_h`` and ``W`` divisible by ``patch_w``.

    Returns:
        Tensor of shape ``[B * P, N, C]``.

    Raises:
        ValueError: If ``x`` is not 4-D (shape unpacking fails).
        RuntimeError: If ``H``/``W`` are not multiples of the patch size.
    """
    b, c, n_h, n_w = _grid_dims(x)
    n = n_h * n_w
    # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
    x = x.reshape(b, c, n_h, patch_h, n_w, patch_w)
    # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
    x = x.transpose(3, 4)
    # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P]
    x = x.reshape(b, c, n, patch_area)
    # [B, C, N, P] -> [B, P, N, C]
    x = x.transpose(1, 3)
    # [B, P, N, C] -> [BP, N, C]; explicit C keeps empty inputs valid.
    return x.reshape(b * patch_area, n, c)


def _elapsed(fn, x: torch.Tensor, iterations: int = 1000) -> float:
    """Return the seconds taken to call ``fn(x)`` ``iterations`` times."""
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(iterations):
        fn(x)
    return time.perf_counter() - start


if __name__ == '__main__':
    t = torch.randn(batch_size, in_channels,
                    num_patch_h * patch_h, num_patch_w * patch_w)
    print(torch.equal(official(t), my_self(t)))

    print(f"official time: {_elapsed(official, t)}")
    print(f"self time: {_elapsed(my_self, t)}")