def _mv2_stage(out_channels: int, expand_ratio: int, num_blocks: int, stride: int) -> dict:
    """Return the settings dict for one ``"mv2"`` stage."""
    return {
        "out_channels": out_channels,
        "expand_ratio": expand_ratio,
        "num_blocks": num_blocks,
        "stride": stride,
        "block_type": "mv2",
    }


def _mobilevit_stage(
    out_channels: int,
    transformer_channels: int,
    ffn_dim: int,
    transformer_blocks: int,
    mv_expand_ratio: int,
) -> dict:
    """Return the settings dict for one ``"mobilevit"`` stage.

    Patch size (2x2), stride (2) and head count (4) are the same for every
    variant, so they are fixed here instead of being repeated per stage.
    """
    return {
        "out_channels": out_channels,
        "transformer_channels": transformer_channels,
        "ffn_dim": ffn_dim,
        "transformer_blocks": transformer_blocks,
        "patch_h": 2,
        "patch_w": 2,
        "stride": 2,
        "mv_expand_ratio": mv_expand_ratio,
        "num_heads": 4,
        "block_type": "mobilevit",
    }


def get_config(mode: str = "xxs") -> dict:
    """Return the layer-by-layer configuration for a model size variant.

    Args:
        mode: Variant name. Accepted values are ``"xx_small"``/``"xxs"``,
            ``"x_small"``/``"xs"`` and ``"small"``/``"s"``. The short names
            are aliases for the long ones; the default ``"xxs"`` selects the
            ``"xx_small"`` preset.

    Returns:
        A freshly built dict with keys ``"layer1"`` .. ``"layer5"`` (one
        settings dict per stage), ``"last_layer_exp_factor"`` and
        ``"cls_dropout"``. Every stage dict additionally carries
        ``"dropout"``, ``"ffn_dropout"`` and ``"attn_dropout"``.

    Raises:
        NotImplementedError: If ``mode`` is not one of the accepted names.
    """
    # Per-variant table:
    #   mv2_exp_mult  - expansion ratio shared by all stages of the variant
    #   mv2_channels  - out_channels of layer1 and layer2
    #   vit_specs     - (out_channels, transformer_channels, ffn_dim,
    #                    transformer_blocks) for layer3, layer4, layer5
    if mode in ("xx_small", "xxs"):
        mv2_exp_mult = 2
        mv2_channels = (16, 24)
        vit_specs = ((48, 64, 128, 2), (64, 80, 160, 4), (80, 96, 192, 3))
    elif mode in ("x_small", "xs"):
        mv2_exp_mult = 4
        mv2_channels = (32, 48)
        vit_specs = ((64, 96, 192, 2), (80, 120, 240, 4), (96, 144, 288, 3))
    elif mode in ("small", "s"):
        mv2_exp_mult = 4
        mv2_channels = (32, 64)
        vit_specs = ((96, 144, 288, 2), (128, 192, 384, 4), (160, 240, 480, 3))
    else:
        raise NotImplementedError(
            f"Unsupported mode {mode!r}; expected one of "
            "'xx_small'/'xxs', 'x_small'/'xs', 'small'/'s'"
        )

    config = {
        "layer1": _mv2_stage(mv2_channels[0], mv2_exp_mult, num_blocks=1, stride=1),
        "layer2": _mv2_stage(mv2_channels[1], mv2_exp_mult, num_blocks=3, stride=2),
    }
    # The original annotated layer3/4/5 as "28x28", "14x14" and "7x7" --
    # presumably the feature-map sizes for a 224x224 input; confirm against
    # the model code.
    for index, spec in enumerate(vit_specs, start=3):
        config[f"layer{index}"] = _mobilevit_stage(*spec, mv_expand_ratio=mv2_exp_mult)
    config["last_layer_exp_factor"] = 4
    config["cls_dropout"] = 0.1

    # The same dropout settings are attached to every stage, including the
    # "mv2" stages (which therefore also receive the ffn/attn dropout keys).
    for name in ("layer1", "layer2", "layer3", "layer4", "layer5"):
        config[name].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})

    return config