42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
import argparse
|
|
import os
|
|
from timm.models import create_model
|
|
from utils import ROOT_PATH
|
|
import torch
|
|
from torch import nn
|
|
from Normalize import Normalize, TfNormalize
|
|
from torch_nets import (
|
|
tf_inception_v3,
|
|
tf_resnet_v2_50,
|
|
tf_resnet_v2_101,
|
|
)
|
|
|
|
# timm architecture names that get_model() will build.
MODEL_NAMES = [
    'vit_base_patch16_224',
    'deit_base_distilled_patch16_224',
    'levit_256',
    'pit_b_224',
    'cait_s24_224',
    'convit_base',
    'tnt_s_patch16_224',
    'visformer_small',
]

# Checkpoint file names kept in the same order as MODEL_NAMES; entry i
# appears to be the weights file for MODEL_NAMES[i] (e.g. 'levit_256' ->
# 'LeViT-256-13b5763e.pth'). NOTE(review): not referenced by the code
# visible here; presumably used elsewhere to locate local weights.
# Keep both lists aligned when editing.
CORR_CKPTS = [
    'jx_vit_base_p16_224-4ee7a4dc.pth',
    'deit_base_distilled_patch16_224-df68dfff.pth',
    'LeViT-256-13b5763e.pth',
    'pit_b_820.pth',
    'S24_224.pth',
    'convit_base.pth',
    'tnt_s_patch16_224.pth.tar',
    'visformer_small-839e1f5b.pth',
]
|
|
|
|
def get_model(model_name):
    """Create a pretrained timm model by architecture name.

    Args:
        model_name: a timm model name; must be one of ``MODEL_NAMES``.

    Returns:
        The model built by ``timm.models.create_model`` with
        ``pretrained=True``, ``num_classes=1000``, ``in_chans=3``,
        ``global_pool=None`` and ``scriptable=False``.

    Raises:
        ValueError: if ``model_name`` is not in ``MODEL_NAMES``.
    """
    # Without this guard an unknown name leaves ``model`` unbound and the
    # function fails later with an opaque UnboundLocalError.
    if model_name not in MODEL_NAMES:
        raise ValueError(
            f"Unsupported model name {model_name!r}; "
            f"expected one of {MODEL_NAMES}"
        )

    model = create_model(
        model_name,
        pretrained=True,
        num_classes=1000,
        in_chans=3,
        # NOTE(review): timm's create_model is understood to drop kwargs
        # whose value is None, so this likely keeps each model's default
        # pooling; confirm against the installed timm version.
        global_pool=None,
        scriptable=False,
    )
    print('Loading Model.')
    return model