177 lines
5.6 KiB
Python
177 lines
5.6 KiB
Python
def get_config(mode: str = "xxs") -> dict:
    """Return the MobileViT architecture configuration for one model size.

    Args:
        mode: Model size. Accepts the long names ``"xx_small"``, ``"x_small"``
            and ``"small"``, or their short aliases ``"xxs"``, ``"xs"`` and
            ``"s"``. The default ``"xxs"`` selects ``"xx_small"``.

    Returns:
        A newly built dict (safe for the caller to mutate) with keys
        ``"layer1"`` .. ``"layer5"`` holding per-stage settings, plus
        ``"last_layer_exp_factor"`` and ``"cls_dropout"``. Every stage dict
        also carries ``"dropout"`` (0.1), ``"ffn_dropout"`` (0.0) and
        ``"attn_dropout"`` (0.0).

    Raises:
        NotImplementedError: If ``mode`` is not a recognised size name.
    """
    # Short names map to the canonical size names. Without these aliases the
    # default argument "xxs" matched no size and the call always raised.
    aliases = {"xxs": "xx_small", "xs": "x_small", "s": "small"}

    # Per-size hyper-parameters:
    #   (expand ratio,
    #    (layer1, layer2) out_channels,
    #    ((out, transformer, ffn) channels for layer3, layer4, layer5))
    sizes = {
        "xx_small": (2, (16, 24), ((48, 64, 128), (64, 80, 160), (80, 96, 192))),
        "x_small": (4, (32, 48), ((64, 96, 192), (80, 120, 240), (96, 144, 288))),
        "small": (4, (32, 64), ((96, 144, 288), (128, 192, 384), (160, 240, 480))),
    }

    # Non-string modes can never match; treating them as unknown also avoids
    # a TypeError from hashing an unhashable argument.
    name = aliases.get(mode, mode) if isinstance(mode, str) else None
    if name not in sizes:
        raise NotImplementedError(
            f"Unsupported MobileViT mode {mode!r}; expected one of "
            f"{sorted(sizes) + sorted(aliases)}"
        )

    exp_mult, (c1, c2), ((c3, t3, f3), (c4, t4, f4), (c5, t5, f5)) = sizes[name]

    # Regularisation shared by every stage. It is unpacked into each stage dict
    # rather than stored, so no two stages share a mutable object.
    stage_dropout = {"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0}

    def mv2_stage(out_channels, num_blocks, stride):
        # Settings for a stage whose block_type is "mv2".
        return {
            "out_channels": out_channels,
            "expand_ratio": exp_mult,
            "num_blocks": num_blocks,
            "stride": stride,
            "block_type": "mv2",
            **stage_dropout,
        }

    def mobilevit_stage(out_channels, transformer_channels, ffn_dim, transformer_blocks):
        # Settings for a stage whose block_type is "mobilevit"; patch size,
        # stride and head count are identical across all sizes.
        return {
            "out_channels": out_channels,
            "transformer_channels": transformer_channels,
            "ffn_dim": ffn_dim,
            "transformer_blocks": transformer_blocks,
            "patch_h": 2,
            "patch_w": 2,
            "stride": 2,
            "mv_expand_ratio": exp_mult,
            "num_heads": 4,
            "block_type": "mobilevit",
            **stage_dropout,
        }

    # The resolution comments are carried over from the original config;
    # presumably they assume a 224x224 input — confirm against the model code.
    return {
        "layer1": mv2_stage(c1, num_blocks=1, stride=1),
        "layer2": mv2_stage(c2, num_blocks=3, stride=2),
        "layer3": mobilevit_stage(c3, t3, f3, transformer_blocks=2),  # 28x28
        "layer4": mobilevit_stage(c4, t4, f4, transformer_blocks=4),  # 14x14
        "layer5": mobilevit_stage(c5, t5, f5, transformer_blocks=3),  # 7x7
        "last_layer_exp_factor": 4,
        "cls_dropout": 0.1,
    }
|