""" original code from apple: https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py """ from typing import Optional, Tuple, Union, Dict import math import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F from transformer import TransformerEncoder from model_config import get_config def make_divisible( v: Union[float, int], divisor: Optional[int] = 8, min_value: Optional[Union[float, int]] = None, ) -> Union[float, int]: """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py :param v: :param divisor: :param min_value: :return: """ if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < 0.9 * v: new_v += divisor return new_v class ConvLayer(nn.Module): """ Applies a 2D convolution over an input Args: in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 groups (Optional[int]): Number of groups in convolution. Default: 1 bias (Optional[bool]): Use bias. Default: ``False`` use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). Default: ``True`` Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, H_{out}, W_{out})` .. note:: For depth-wise convolution, `groups=C_{in}=C_{out}`. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = 1, groups: Optional[int] = 1, bias: Optional[bool] = False, use_norm: Optional[bool] = True, use_act: Optional[bool] = True, ) -> None: super().__init__() if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) if isinstance(stride, int): stride = (stride, stride) assert isinstance(kernel_size, Tuple) assert isinstance(stride, Tuple) padding = ( int((kernel_size[0] - 1) / 2), int((kernel_size[1] - 1) / 2), ) block = nn.Sequential() conv_layer = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, groups=groups, padding=padding, bias=bias ) block.add_module(name="conv", module=conv_layer) if use_norm: norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1) block.add_module(name="norm", module=norm_layer) if use_act: act_layer = nn.SiLU() block.add_module(name="act", module=act_layer) self.block = block def forward(self, x: Tensor) -> Tensor: return self.block(x) class InvertedResidual(nn.Module): """ This class implements the inverted residual block, as described in `MobileNetv2 `_ paper Args: in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)` stride (int): Use convolutions with a stride. Default: 1 expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv skip_connection (Optional[bool]): Use skip-connection. Default: True Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` - Output: :math:`(N, C_{out}, H_{out}, W_{out})` .. note:: If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False` """ def __init__( self, in_channels: int, out_channels: int, stride: int, expand_ratio: Union[int, float], skip_connection: Optional[bool] = True, ) -> None: assert stride in [1, 2] hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8) super().__init__() block = nn.Sequential() if expand_ratio != 1: block.add_module( name="exp_1x1", module=ConvLayer( in_channels=in_channels, out_channels=hidden_dim, kernel_size=1 ), ) block.add_module( name="conv_3x3", module=ConvLayer( in_channels=hidden_dim, out_channels=hidden_dim, stride=stride, kernel_size=3, groups=hidden_dim ), ) block.add_module( name="red_1x1", module=ConvLayer( in_channels=hidden_dim, out_channels=out_channels, kernel_size=1, use_act=False, use_norm=True, ), ) self.block = block self.in_channels = in_channels self.out_channels = out_channels self.exp = expand_ratio self.stride = stride self.use_res_connect = ( self.stride == 1 and in_channels == out_channels and skip_connection ) def forward(self, x: Tensor, *args, **kwargs) -> Tensor: if self.use_res_connect: return x + self.block(x) else: return self.block(x) class MobileViTBlock(nn.Module): """ This class defines the `MobileViT block `_ Args: opts: command line arguments in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` transformer_dim (int): Input dimension to the transformer unit ffn_dim (int): Dimension of the FFN block n_transformer_blocks (int): Number of transformer blocks. Default: 2 head_dim (int): Head dimension in the multi-head attention. Default: 32 attn_dropout (float): Dropout in multi-head attention. Default: 0.0 dropout (float): Dropout rate. Default: 0.0 ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0 patch_h (int): Patch height for unfolding operation. Default: 8 patch_w (int): Patch width for unfolding operation. Default: 8 transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3 no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False """ def __init__( self, in_channels: int, transformer_dim: int, ffn_dim: int, n_transformer_blocks: int = 2, head_dim: int = 32, attn_dropout: float = 0.0, dropout: float = 0.0, ffn_dropout: float = 0.0, patch_h: int = 8, patch_w: int = 8, conv_ksize: Optional[int] = 3, *args, **kwargs ) -> None: super().__init__() conv_3x3_in = ConvLayer( in_channels=in_channels, out_channels=in_channels, kernel_size=conv_ksize, stride=1 ) conv_1x1_in = ConvLayer( in_channels=in_channels, out_channels=transformer_dim, kernel_size=1, stride=1, use_norm=False, use_act=False ) conv_1x1_out = ConvLayer( in_channels=transformer_dim, out_channels=in_channels, kernel_size=1, stride=1 ) conv_3x3_out = ConvLayer( in_channels=2 * in_channels, out_channels=in_channels, kernel_size=conv_ksize, stride=1 ) self.local_rep = nn.Sequential() self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in) self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in) assert transformer_dim % head_dim == 0 num_heads = transformer_dim // head_dim global_rep = [ TransformerEncoder( embed_dim=transformer_dim, ffn_latent_dim=ffn_dim, num_heads=num_heads, attn_dropout=attn_dropout, dropout=dropout, ffn_dropout=ffn_dropout ) for _ in range(n_transformer_blocks) ] global_rep.append(nn.LayerNorm(transformer_dim)) self.global_rep = nn.Sequential(*global_rep) self.conv_proj = conv_1x1_out self.fusion = conv_3x3_out self.patch_h = patch_h self.patch_w = patch_w self.patch_area = self.patch_w * self.patch_h self.cnn_in_dim = in_channels self.cnn_out_dim = transformer_dim self.n_heads = num_heads self.ffn_dim = ffn_dim self.dropout = dropout self.attn_dropout = attn_dropout self.ffn_dropout = ffn_dropout self.n_blocks = n_transformer_blocks self.conv_ksize = conv_ksize def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]: patch_w, patch_h = self.patch_w, self.patch_h patch_area = patch_w * patch_h batch_size, in_channels, orig_h, orig_w = x.shape new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h) new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w) interpolate = False if new_w != orig_w or new_h != orig_h: # Note: Padding can be done, but then it needs to be handled in attention function. x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False) interpolate = True # number of patches along width and height num_patch_w = new_w // patch_w # n_w num_patch_h = new_h // patch_h # n_h num_patches = num_patch_h * num_patch_w # N # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w] x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w) # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w] x = x.transpose(1, 2) # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w x = x.reshape(batch_size, in_channels, num_patches, patch_area) # [B, C, N, P] -> [B, P, N, C] x = x.transpose(1, 3) # [B, P, N, C] -> [BP, N, C] x = x.reshape(batch_size * patch_area, num_patches, -1) info_dict = { "orig_size": (orig_h, orig_w), "batch_size": batch_size, "interpolate": interpolate, "total_patches": num_patches, "num_patches_w": num_patch_w, "num_patches_h": num_patch_h, } return x, info_dict def folding(self, x: Tensor, info_dict: Dict) -> Tensor: n_dim = x.dim() assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format( x.shape ) # [BP, N, C] --> [B, P, N, C] x = x.contiguous().view( info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1 ) batch_size, pixels, num_patches, channels = x.size() num_patch_h = info_dict["num_patches_h"] num_patch_w = info_dict["num_patches_w"] # [B, P, N, C] -> [B, C, N, P] x = x.transpose(1, 3) # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w] x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w) # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w] x = x.transpose(1, 2) # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W] x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w) if info_dict["interpolate"]: x = F.interpolate( x, size=info_dict["orig_size"], mode="bilinear", align_corners=False, ) return x def forward(self, x: Tensor) -> Tensor: res = x fm = self.local_rep(x) # convert feature map to patches patches, info_dict = self.unfolding(fm) # learn global representations for transformer_layer in self.global_rep: patches = transformer_layer(patches) # [B x Patch x Patches x C] -> [B x C x Patches x Patch] fm = self.folding(x=patches, info_dict=info_dict) fm = self.conv_proj(fm) fm = self.fusion(torch.cat((res, fm), dim=1)) return fm class MobileViT(nn.Module): """ This class implements the `MobileViT architecture `_ """ def __init__(self, model_cfg: Dict, num_classes: int = 1000): super().__init__() image_channels = 3 out_channels = 16 self.conv_1 = ConvLayer( in_channels=image_channels, out_channels=out_channels, kernel_size=3, stride=2 ) self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"]) self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"]) self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"]) self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"]) self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"]) exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960) self.conv_1x1_exp = ConvLayer( in_channels=out_channels, out_channels=exp_channels, kernel_size=1 ) self.classifier = nn.Sequential() self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1)) self.classifier.add_module(name="flatten", module=nn.Flatten()) if 0.0 < model_cfg["cls_dropout"] < 1.0: self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"])) self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes)) # weight init self.apply(self.init_parameters) def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]: block_type = cfg.get("block_type", "mobilevit") if block_type.lower() == "mobilevit": return self._make_mit_layer(input_channel=input_channel, cfg=cfg) else: return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg) @staticmethod def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]: output_channels = cfg.get("out_channels") num_blocks = cfg.get("num_blocks", 2) expand_ratio = cfg.get("expand_ratio", 4) block = [] for i in range(num_blocks): stride = cfg.get("stride", 1) if i == 0 else 1 layer = InvertedResidual( in_channels=input_channel, out_channels=output_channels, stride=stride, expand_ratio=expand_ratio ) block.append(layer) input_channel = output_channels return nn.Sequential(*block), input_channel @staticmethod def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]: stride = cfg.get("stride", 1) block = [] if stride == 2: layer = InvertedResidual( in_channels=input_channel, out_channels=cfg.get("out_channels"), stride=stride, expand_ratio=cfg.get("mv_expand_ratio", 4) ) block.append(layer) input_channel = cfg.get("out_channels") transformer_dim = cfg["transformer_channels"] ffn_dim = cfg.get("ffn_dim") num_heads = cfg.get("num_heads", 4) head_dim = transformer_dim // num_heads if transformer_dim % head_dim != 0: raise ValueError("Transformer input dimension should be divisible by head dimension. " "Got {} and {}.".format(transformer_dim, head_dim)) block.append(MobileViTBlock( in_channels=input_channel, transformer_dim=transformer_dim, ffn_dim=ffn_dim, n_transformer_blocks=cfg.get("transformer_blocks", 1), patch_h=cfg.get("patch_h", 2), patch_w=cfg.get("patch_w", 2), dropout=cfg.get("dropout", 0.1), ffn_dropout=cfg.get("ffn_dropout", 0.0), attn_dropout=cfg.get("attn_dropout", 0.1), head_dim=head_dim, conv_ksize=3 )) return nn.Sequential(*block), input_channel @staticmethod def init_parameters(m): if isinstance(m, nn.Conv2d): if m.weight is not None: nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): if m.weight is not None: nn.init.ones_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.Linear,)): if m.weight is not None: nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) else: pass def forward(self, x: Tensor) -> Tensor: x = self.conv_1(x) x = self.layer_1(x) x = self.layer_2(x) x = self.layer_3(x) x = self.layer_4(x) x = self.layer_5(x) x = self.conv_1x1_exp(x) x = self.classifier(x) return x def mobile_vit_xx_small(num_classes: int = 1000): # pretrain weight link # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt config = get_config("xx_small") m = MobileViT(config, num_classes=num_classes) return m def mobile_vit_x_small(num_classes: int = 1000): # pretrain weight link # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt config = get_config("x_small") m = MobileViT(config, num_classes=num_classes) return m def mobile_vit_small(num_classes: int = 1000): # pretrain weight link # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt config = get_config("small") m = MobileViT(config, num_classes=num_classes) return m