Graduation_Project/WZM/model/pyramid_vig.py

241 lines
8.9 KiB
Python
Raw Permalink Normal View History

2024-06-24 19:41:48 +08:00
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <foss@huawei.com>
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# from timm.models.helpers import load_pretrained
# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# from timm.models.registry import register_model
from .gcn_lib import DropPath
from .gcn_lib import Grapher, act_layer
from .GAT import GAT, GATopt
# from gcn_lib import DropPath
# from gcn_lib import Grapher, act_layer
# from GAT import GAT, GATopt
def _cfg(url='', **kwargs):
    """Build a default model/data configuration dictionary.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries. A key that already exists among the defaults
            replaces the default value; any other key is appended.

    Returns:
        A new ``dict`` holding the defaults merged with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        first_conv='patch_embed.proj',
        classifier='head',
    )
    # Caller-supplied entries win over the defaults above.
    cfg.update(kwargs)
    return cfg
# def _cfg(url='', **kwargs):
# return {
# 'url': url,
# 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
# 'crop_pct': .9, 'interpolation': 'bicubic',
# 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
# 'first_conv': 'patch_embed.proj', 'classifier': 'head',
# **kwargs
# }
# Default configurations keyed by model-variant name. Both entries override the
# normalisation statistics; the "b" variant additionally overrides crop_pct.
default_cfgs = dict(
    vig_224_gelu=_cfg(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    vig_b_224_gelu=_cfg(crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
class FFN(nn.Module):
    """Residual feed-forward block made of two 1x1 convolutions.

    Each 1x1 ``Conv2d`` is followed by a ``BatchNorm2d``; the activation
    returned by ``act_layer(act)`` sits between the two stages. The result of
    that branch goes through ``drop_path`` and is added to the block input.

    Args:
        in_features: Number of input channels.
        hidden_features: Channels between the two convolutions. A falsy value
            (``None`` or ``0``) falls back to ``in_features``.
        out_features: Number of output channels. A falsy value falls back to
            ``in_features``. The residual addition in ``forward`` adds the
            branch output to the input, so the two must be shape-compatible.
        act: Activation name forwarded to ``act_layer``.
        drop_path: Rate forwarded to ``DropPath``; a rate that is not greater
            than zero installs ``nn.Identity`` instead.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        expand = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0)
        self.fc1 = nn.Sequential(expand, nn.BatchNorm2d(hidden_features))

        self.act = act_layer(act)

        project = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0)
        self.fc2 = nn.Sequential(project, nn.BatchNorm2d(out_features))

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return ``drop_path(fc2(act(fc1(x)))) + x``."""
        branch = self.fc2(self.act(self.fc1(x)))
        return self.drop_path(branch) + x
class Stem(nn.Module):
    """Image to visual embedding via three overlapping 3x3 convolutions.

    Overlap: https://arxiv.org/pdf/2106.13797.pdf

    The first two convolutions use stride 2 (so height and width shrink by a
    factor of four overall); the third keeps the resolution. Every convolution
    is followed by ``BatchNorm2d``, and the first two are additionally followed
    by the activation returned by ``act_layer(act)``.

    Args:
        img_size: Accepted for signature compatibility; not used here.
        in_dim: Number of input image channels.
        out_dim: Number of output channels; the first convolution produces
            ``out_dim // 2`` channels.
        act: Activation name forwarded to ``act_layer``.
    """

    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        half_dim = out_dim // 2
        layers = [
            # stride-2 stage: in_dim -> out_dim // 2
            nn.Conv2d(in_dim, half_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(half_dim),
            act_layer(act),
            # stride-2 stage: out_dim // 2 -> out_dim
            nn.Conv2d(half_dim, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            act_layer(act),
            # resolution-preserving refinement, no trailing activation
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        ]
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return the result."""
        return self.convs(x)
class Downsample(nn.Module):
    """Convolution-based downsample.

    A single 3x3 convolution with stride 2 and padding 1 followed by
    ``BatchNorm2d``; it changes the channel count from ``in_dim`` to
    ``out_dim`` while halving height and width (rounded up for odd sizes).

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
    """

    def __init__(self, in_dim=3, out_dim=768):
        super().__init__()
        strided_conv = nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1)
        batch_norm = nn.BatchNorm2d(out_dim)
        self.conv = nn.Sequential(strided_conv, batch_norm)

    def forward(self, x):
        """Return the downsampled feature map for ``x``."""
        return self.conv(x)
class DeepGCN(torch.nn.Module):
    """Pyramid ViG backbone that also returns two intermediate feature maps.

    The network is ``Stem`` -> learnable positional embedding -> a sequence of
    stages (each stage is a run of ``Grapher`` + ``FFN`` pairs, with a
    ``Downsample`` inserted before every stage except the first) -> global
    average pooling -> 1x1-convolution prediction head.

    The positional embedding is allocated for a 56x56 grid (``224 // 4``), and
    ``Stem`` reduces resolution by four, so the model expects 224x224 inputs.

    Args:
        opt: Configuration object. The attributes read here are ``k``, ``act``,
            ``norm``, ``bias``, ``epsilon``, ``use_stochastic``, ``conv``,
            ``drop_path``, ``blocks``, ``channels``, ``dropout`` and
            ``n_classes``. ``blocks`` and ``channels`` are per-stage lists;
            at most four stages are supported because ``reduce_ratios`` has
            four entries.
    """

    # Positions inside ``self.backbone`` whose outputs are returned as the
    # auxiliary "low" and "mid" features. With blocks == [2, 2, 6, 2] the
    # backbone has 15 entries and indices 5 and 12 are the Downsample layers
    # that open the third and fourth stages. Other layouts keep the same fixed
    # indices; a backbone shorter than 13 entries yields None for the missing
    # feature(s).
    _LOW_FEATURE_INDEX = 5
    _MID_FEATURE_INDEX = 12

    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        print(opt)
        k = opt.k
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.use_stochastic
        conv = opt.conv
        drop_path = opt.drop_path
        blocks = opt.blocks
        channels = opt.channels

        self.n_blocks = sum(blocks)
        reduce_ratios = [4, 2, 1, 1]
        # stochastic depth decay rule: linearly increasing drop-path rate per block
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]
        # number of knn's k, one entry per block (constant here)
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]
        max_dilation = 49 // max(num_knn)

        self.stem = Stem(out_dim=channels[0], act=act)

        # Spatial side length after the stem for a 224x224 input.
        grid_side = 224 // 4
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], grid_side, grid_side))
        # Number of spatial positions, passed to Grapher as ``n``.
        HW = grid_side * grid_side

        layers = []
        idx = 0  # running block index across all stages
        for stage, depth in enumerate(blocks):
            if stage > 0:
                layers.append(Downsample(channels[stage - 1], channels[stage]))
                # Stride-2 downsampling quarters the number of positions.
                HW = HW // 4
            for _ in range(depth):
                layers.append(
                    Seq(Grapher(channels[stage], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm,
                                bias, stochastic, epsilon, reduce_ratios[stage], n=HW, drop_path=dpr[idx],
                                relative_pos=True),
                        FFN(channels[stage], channels[stage] * 4, act=act, drop_path=dpr[idx])
                        ))
                idx += 1
        self.backbone = Seq(*layers)

        self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(act),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, opt.n_classes, 1, bias=True))
        self.model_init()

    def model_init(self):
        """Kaiming-initialise every ``Conv2d`` weight and zero every conv bias.

        Also marks those weights and biases as trainable.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, inputs):
        """Run the backbone and return intermediate and pooled features.

        Args:
            inputs: Image batch fed to the stem.

        Returns:
            A tuple ``(low_feature, mid_feature, solo_feature)`` where the
            first two are the outputs of backbone entries 5 and 12 (``None``
            if the backbone is too short to have them) and ``solo_feature`` is
            the prediction head applied to the globally average-pooled final
            feature map, with the two trailing singleton dimensions squeezed.
        """
        x = self.stem(inputs) + self.pos_embed
        low_feature, mid_feature = None, None
        for i, layer in enumerate(self.backbone):
            x = layer(x)
            if i == self._LOW_FEATURE_INDEX:
                low_feature = x
            elif i == self._MID_FEATURE_INDEX:
                mid_feature = x
        x = F.adaptive_avg_pool2d(x, 1)
        solo_feature = self.prediction(x).squeeze(-1).squeeze(-1)
        return low_feature, mid_feature, solo_feature
def pvig_ti_224_gelu(pretrained=False, checkpoint_path='/home/wzm/GAC/model/pvig_ti_78.5.pth.tar', **kwargs):
    """Build the tiny Pyramid ViG backbone with a 512-dimensional output head.

    The model is created with its original prediction head, weights are loaded
    from ``checkpoint_path``, and only then is the head replaced by a fresh
    ``384 -> 512 -> 512`` 1x1-convolution head. The replacement head is not
    passed through ``DeepGCN.model_init``; it keeps PyTorch's default
    initialisation.

    Args:
        pretrained: Currently ignored. Weights are loaded whenever
            ``checkpoint_path`` is not ``None``; the flag is kept so existing
            callers keep working.
        checkpoint_path: File passed to ``torch.load``. Pass ``None`` to skip
            loading and get a randomly initialised model. Defaults to the
            path the original code hard-coded.
        **kwargs: Forwarded to the option object; ``num_classes`` and
            ``drop_path_rate`` are recognised, anything else is discarded.
            NOTE(review): ``load_state_dict(strict=False)`` still raises on
            tensor size mismatches, so a ``num_classes`` different from the
            checkpoint's presumably fails during loading — confirm.

    Returns:
        A ``DeepGCN`` instance whose ``prediction`` head outputs 512 channels
        and whose ``default_cfg`` is ``default_cfgs['vig_224_gelu']``.
    """
    class OptInit:
        def __init__(self, num_classes=1000, drop_path_rate=0.0, **kwargs):
            self.k = 9  # neighbor num (default:9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.act = 'gelu'  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer True or False
            self.dropout = 0.0  # dropout rate
            self.use_dilation = True  # use dilated knn or not
            self.epsilon = 0.2  # stochastic epsilon for gcn
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.drop_path = drop_path_rate
            self.blocks = [2, 2, 6, 2]  # number of basic blocks in the backbone
            self.channels = [48, 96, 240, 384]  # number of channels of deep features
            self.n_classes = num_classes  # Dimension of out_channels
            self.emb_dims = 1024  # Dimension of embeddings

    opt = OptInit(**kwargs)
    model = DeepGCN(opt)
    model.default_cfg = default_cfgs['vig_224_gelu']

    # Load the pretrained weights.
    if checkpoint_path is not None:
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' lets a checkpoint saved from a CUDA device be read
        # on a machine without one; load_state_dict copies the values into the
        # model's own parameters regardless of where the loaded tensors live.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        model.load_state_dict(state_dict, strict=False)

    # model.backbone[2].add_module('GAT', GAT(GATopt(96, 1)))
    # model.backbone[5].add_module('GAT', GAT(GATopt(240, 1)))
    # model.backbone[12].add_module('GAT', GAT(GATopt(384, 1)))

    # Swap the original head for a 512-d embedding head. The input width
    # follows the last stage's channel count (384 for this variant).
    model.prediction = Seq(nn.Conv2d(opt.channels[-1], 512, 1, bias=True),
                           nn.BatchNorm2d(512),
                           act_layer('gelu'),
                           nn.Dropout(opt.dropout),
                           nn.Conv2d(512, 512, 1, bias=True))

    # Freeze pre-train model
    # for param in model.parameters():
    #     param.requires_grad = False
    # # unfreeze prediction param
    # for name, param in model.named_parameters():
    #     if "prediction" in name:
    #         param.requires_grad = True
    return model