import collections.abc as container_abc
from collections import OrderedDict
from math import ceil, floor
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from coordatt import CoordAtt  # used only by the commented-out CoordAtt1 head below
# Reference implementation of CoordAtt, kept commented out (note: it relies on an
# h_swish activation that is not defined in this file):
# class CoordAtt(nn.Module):
# def __init__(self, inp, oup, groups=32):
# super(CoordAtt, self).__init__()
# self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
# self.pool_w = nn.AdaptiveAvgPool2d((1, None))
# mip = max(8, inp // groups)
# self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
# self.bn1 = nn.BatchNorm2d(mip)
# self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
# self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
# self.relu = h_swish()
# def forward(self, x):
# identity = x
# n,c,h,w = x.size()
# x_h = self.pool_h(x)
# x_w = self.pool_w(x).permute(0, 1, 3, 2)
# y = torch.cat([x_h, x_w], dim=2)
# y = self.conv1(y)
# y = self.bn1(y)
# y = self.relu(y)
# x_h, x_w = torch.split(y, [h, w], dim=2)
# x_w = x_w.permute(0, 1, 3, 2)
# x_h = self.conv2(x_h).sigmoid()
# x_w = self.conv3(x_w).sigmoid()
# x_h = x_h.expand(-1, -1, h, w)
# x_w = x_w.expand(-1, -1, h, w)
# y = identity * x_w * x_h
# return y
def _pair(x):
if isinstance(x, container_abc.Iterable):
return x
return (x, x)
def torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride, padding=0, dilation=1):
    if in_spatial_shape is None:
        return None
    # in_spatial_shape -> [H, W]
    hin, win = _pair(in_spatial_shape)
    kh, kw = _pair(kernel_size)
    sh, sw = _pair(stride)
    ph, pw = _pair(padding)
    dh, dw = _pair(dilation)
    # standard nn.Conv2d output-shape formula:
    #   out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    # callers in this file pass the padding their conv actually uses
    hout = int(floor((hin + 2 * ph - dh * (kh - 1) - 1) / sh + 1))
    wout = int(floor((win + 2 * pw - dw * (kw - 1) - 1) / sw + 1))
    return hout, wout
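# Quick sanity check (sketch): with the padding actually used in this file, these
# match nn.Conv2d shape inference, e.g.
#   torch_conv_out_spatial_shape(224, kernel_size=3, stride=2, padding=1)  # -> (112, 112)
#   torch_conv_out_spatial_shape(224, kernel_size=3, stride=1, padding=1)  # -> (224, 224)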
def get_activation(act_fn: str, **kwargs):
if act_fn in ('silu', 'swish'):
return nn.SiLU(**kwargs)
elif act_fn == 'relu':
return nn.ReLU(**kwargs)
elif act_fn == 'relu6':
return nn.ReLU6(**kwargs)
elif act_fn == 'elu':
return nn.ELU(**kwargs)
elif act_fn == 'leaky_relu':
return nn.LeakyReLU(**kwargs)
elif act_fn == 'selu':
return nn.SELU(**kwargs)
elif act_fn == 'mish':
return nn.Mish(**kwargs)
else:
raise ValueError('Unsupported act_fn {}'.format(act_fn))
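# Usage sketch: extra kwargs are forwarded to the underlying nn module, e.g.
#   get_activation('silu', inplace=True)              # nn.SiLU(inplace=True)
#   get_activation('leaky_relu', negative_slope=0.1)  # nn.LeakyReLU(negative_slope=0.1)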
def round_filters(filters, width_coefficient, depth_divisor=8):
"""Round number of filters based on depth multiplier."""
min_depth = depth_divisor
filters *= width_coefficient
new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
return int(new_filters)
def round_repeats(repeats, depth_coefficient):
"""Round number of filters based on depth multiplier."""
return int(ceil(depth_coefficient * repeats))
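# Worked examples (sketch) for the two scaling helpers above:
#   round_filters(32, width_coefficient=1.2)  # 38.4 -> 40, the nearest multiple of 8
#   round_repeats(5, depth_coefficient=1.4)   # ceil(7.0) -> 7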
class DropConnect(nn.Module):
def __init__(self, rate=0.5):
super(DropConnect, self).__init__()
self.keep_prob = None
self.set_rate(rate)
def set_rate(self, rate):
if not 0 <= rate < 1:
raise ValueError("rate must be 0<=rate<1, got {} instead".format(rate))
self.keep_prob = 1 - rate
def forward(self, x):
if self.training:
random_tensor = self.keep_prob + torch.rand([x.size(0), 1, 1, 1],
dtype=x.dtype,
device=x.device)
binary_tensor = torch.floor(random_tensor)
return torch.mul(torch.div(x, self.keep_prob), binary_tensor)
else:
return x
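# DropConnect implements per-sample stochastic depth: in training, each example in the
# batch is zeroed with probability `rate`, and survivors are scaled by 1/keep_prob so
# the expectation is unchanged. Minimal sketch:
#   dc = DropConnect(rate=0.2)
#   dc.train()
#   y = dc(torch.ones(8, 3, 4, 4))  # on average ~20% of samples zeroed, rest scaled by 1.25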
class SamePaddingConv2d(nn.Module):
def __init__(self,
in_spatial_shape,
in_channels,
out_channels,
kernel_size,
stride,
dilation=1,
enforce_in_spatial_shape=False,
**kwargs):
super(SamePaddingConv2d, self).__init__()
self._in_spatial_shape = _pair(in_spatial_shape)
# e.g. throw exception if input spatial shape does not match in_spatial_shape
# when calling self.forward()
self.enforce_in_spatial_shape = enforce_in_spatial_shape
kernel_size = _pair(kernel_size)
stride = _pair(stride)
dilation = _pair(dilation)
in_height, in_width = self._in_spatial_shape
filter_height, filter_width = kernel_size
        stride_height, stride_width = stride
        dilation_height, dilation_width = dilation
        out_height = int(ceil(float(in_height) / float(stride_height)))
        out_width = int(ceil(float(in_width) / float(stride_width)))
        pad_along_height = max((out_height - 1) * stride_height +
                               filter_height + (filter_height - 1) * (dilation_height - 1) - in_height, 0)
        pad_along_width = max((out_width - 1) * stride_width +
                              filter_width + (filter_width - 1) * (dilation_width - 1) - in_width, 0)
pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
paddings = (pad_left, pad_right, pad_top, pad_bottom)
if any(p > 0 for p in paddings):
self.zero_pad = nn.ZeroPad2d(paddings)
else:
self.zero_pad = None
self.conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
self._out_spatial_shape = (out_height, out_width)
@property
def out_spatial_shape(self):
return self._out_spatial_shape
def check_spatial_shape(self, x):
if x.size(2) != self._in_spatial_shape[0] or \
x.size(3) != self._in_spatial_shape[1]:
raise ValueError(
"Expected input spatial shape {}, got {} instead".format(self._in_spatial_shape, x.shape[2:]))
def forward(self, x):
if self.enforce_in_spatial_shape:
self.check_spatial_shape(x)
if self.zero_pad is not None:
x = self.zero_pad(x)
x = self.conv(x)
return x
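# SamePaddingConv2d reproduces TensorFlow's 'SAME' padding, which may be asymmetric
# (nn.Conv2d's `padding=` argument is always symmetric). Minimal sketch:
#   conv = SamePaddingConv2d(in_spatial_shape=224, in_channels=3, out_channels=24,
#                            kernel_size=3, stride=2)
#   y = conv(torch.randn(1, 3, 224, 224))  # (1, 24, 112, 112); conv.out_spatial_shape == (112, 112)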
class SqueezeExcitate(nn.Module):
def __init__(self,
in_channels,
se_size,
activation=None):
super(SqueezeExcitate, self).__init__()
self.dim_reduce = nn.Conv2d(in_channels=in_channels,
out_channels=se_size,
kernel_size=1)
self.dim_restore = nn.Conv2d(in_channels=se_size,
out_channels=in_channels,
kernel_size=1)
self.activation = F.relu if activation is None else activation
def forward(self, x):
inp = x
x = F.adaptive_avg_pool2d(x, (1, 1))
x = self.dim_reduce(x)
x = self.activation(x)
x = self.dim_restore(x)
x = torch.sigmoid(x)
return torch.mul(inp, x)
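# Squeeze-and-Excitation recalibrates channels: global average pool to 1x1, reduce to
# se_size channels, restore, then sigmoid-gate the input. Minimal sketch:
#   se = SqueezeExcitate(in_channels=96, se_size=24, activation=nn.SiLU())
#   y = se(torch.randn(2, 96, 28, 28))  # same shape as the input, channel-wise rescaled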
class MBConvBlockV2(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
expansion_factor,
act_fn,
act_kwargs=None,
bn_epsilon=None,
bn_momentum=None,
se_size=None,
drop_connect_rate=None,
bias=False,
tf_style_conv=False,
in_spatial_shape=None):
super().__init__()
if act_kwargs is None:
act_kwargs = {}
exp_channels = in_channels * expansion_factor
self.ops_lst = []
# expansion convolution
if expansion_factor != 1:
self.expand_conv = nn.Conv2d(in_channels=in_channels,
out_channels=exp_channels,
kernel_size=1,
bias=bias)
self.expand_bn = nn.BatchNorm2d(num_features=exp_channels,
eps=bn_epsilon,
momentum=bn_momentum)
self.expand_act = get_activation(act_fn, **act_kwargs)
self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])
# depth-wise convolution
if tf_style_conv:
self.dp_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape,
in_channels=exp_channels,
out_channels=exp_channels,
kernel_size=kernel_size,
stride=stride,
groups=exp_channels,
bias=bias)
self.out_spatial_shape = self.dp_conv.out_spatial_shape
else:
self.dp_conv = nn.Conv2d(in_channels=exp_channels,
out_channels=exp_channels,
kernel_size=kernel_size,
stride=stride,
padding=1,
groups=exp_channels,
bias=bias)
            # padding=1 matches the depthwise conv above (all V2 configs use 3x3 kernels)
            self.out_spatial_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride, padding=1)
self.dp_bn = nn.BatchNorm2d(num_features=exp_channels,
eps=bn_epsilon,
momentum=bn_momentum)
self.dp_act = get_activation(act_fn, **act_kwargs)
self.ops_lst.extend([self.dp_conv, self.dp_bn, self.dp_act])
# Squeeze and Excitate
if se_size is not None:
self.se = SqueezeExcitate(exp_channels,
se_size,
activation=get_activation(act_fn, **act_kwargs))
self.ops_lst.append(self.se)
# projection layer
self.project_conv = nn.Conv2d(in_channels=exp_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias)
self.project_bn = nn.BatchNorm2d(num_features=out_channels,
eps=bn_epsilon,
momentum=bn_momentum)
# no activation function in projection layer
self.ops_lst.extend([self.project_conv, self.project_bn])
self.skip_enabled = in_channels == out_channels and stride == 1
if self.skip_enabled and drop_connect_rate is not None:
self.drop_connect = DropConnect(drop_connect_rate)
self.ops_lst.append(self.drop_connect)
def forward(self, x):
inp = x
for op in self.ops_lst:
x = op(x)
if self.skip_enabled:
return x + inp
else:
return x
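# Minimal sketch of a stand-alone MBConv block (illustrative values, not from a model
# config; bn_epsilon must be a float because it is passed straight to nn.BatchNorm2d):
#   block = MBConvBlockV2(in_channels=48, out_channels=48, kernel_size=3, stride=1,
#                         expansion_factor=4, act_fn='silu', bn_epsilon=1e-3,
#                         se_size=12, drop_connect_rate=0.1)
#   y = block(torch.randn(2, 48, 32, 32))  # stride 1, equal channels -> residual skip; shape preserved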
class FusedMBConvBlockV2(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
expansion_factor,
act_fn,
act_kwargs=None,
bn_epsilon=None,
bn_momentum=None,
se_size=None,
drop_connect_rate=None,
bias=False,
tf_style_conv=False,
in_spatial_shape=None):
super().__init__()
if act_kwargs is None:
act_kwargs = {}
exp_channels = in_channels * expansion_factor
self.ops_lst = []
# expansion convolution
expansion_out_shape = in_spatial_shape
if expansion_factor != 1:
if tf_style_conv:
self.expand_conv = SamePaddingConv2d(in_spatial_shape=in_spatial_shape,
in_channels=in_channels,
out_channels=exp_channels,
kernel_size=kernel_size,
stride=stride,
bias=bias)
expansion_out_shape = self.expand_conv.out_spatial_shape
else:
self.expand_conv = nn.Conv2d(in_channels=in_channels,
out_channels=exp_channels,
kernel_size=kernel_size,
padding=1,
stride=stride,
bias=bias)
                expansion_out_shape = torch_conv_out_spatial_shape(in_spatial_shape, kernel_size, stride, padding=1)
self.expand_bn = nn.BatchNorm2d(num_features=exp_channels,
eps=bn_epsilon,
momentum=bn_momentum)
self.expand_act = get_activation(act_fn, **act_kwargs)
self.ops_lst.extend([self.expand_conv, self.expand_bn, self.expand_act])
# Squeeze and Excitate
if se_size is not None:
self.se = SqueezeExcitate(exp_channels,
se_size,
activation=get_activation(act_fn, **act_kwargs))
self.ops_lst.append(self.se)
# projection layer
        # when an expansion conv is present it has already applied kernel_size/stride,
        # so the projection becomes a plain 1x1, stride-1 conv
        spatial_stride = stride  # the block's true stride, for the skip check below
        kernel_size = 1 if expansion_factor != 1 else kernel_size
        stride = 1 if expansion_factor != 1 else stride
if tf_style_conv:
self.project_conv = SamePaddingConv2d(in_spatial_shape=expansion_out_shape,
in_channels=exp_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
bias=bias)
self.out_spatial_shape = self.project_conv.out_spatial_shape
else:
self.project_conv = nn.Conv2d(in_channels=exp_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=1 if kernel_size > 1 else 0,
bias=bias)
            self.out_spatial_shape = torch_conv_out_spatial_shape(expansion_out_shape, kernel_size, stride,
                                                                  padding=1 if kernel_size > 1 else 0)
self.project_bn = nn.BatchNorm2d(num_features=out_channels,
eps=bn_epsilon,
momentum=bn_momentum)
self.ops_lst.extend(
[self.project_conv, self.project_bn])
if expansion_factor == 1:
self.project_act = get_activation(act_fn, **act_kwargs)
self.ops_lst.append(self.project_act)
        # use the true spatial stride: after an expansion conv, `stride` was reset to 1
        # above even though the block may have downsampled
        self.skip_enabled = in_channels == out_channels and spatial_stride == 1
if self.skip_enabled and drop_connect_rate is not None:
self.drop_connect = DropConnect(drop_connect_rate)
self.ops_lst.append(self.drop_connect)
def forward(self, x):
inp = x
for op in self.ops_lst:
x = op(x)
if self.skip_enabled:
return x + inp
else:
return x
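# Fused-MBConv replaces the 1x1 expansion + depthwise 3x3 pair with a single full 3x3
# convolution, which runs faster on accelerators in the early high-resolution stages.
# Minimal sketch (illustrative values):
#   block = FusedMBConvBlockV2(in_channels=24, out_channels=48, kernel_size=3, stride=2,
#                              expansion_factor=4, act_fn='silu', bn_epsilon=1e-3)
#   y = block(torch.randn(2, 24, 64, 64))  # -> (2, 48, 32, 32); no skip (stride 2, channels change)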
class EfficientNetV2(nn.Module):
_models = {'b0': {'num_repeat': [1, 2, 2, 3, 5, 8],
'kernel_size': [3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2],
'expand_ratio': [1, 4, 4, 4, 6, 6],
'in_channel': [32, 16, 32, 48, 96, 112],
'out_channel': [16, 32, 48, 96, 112, 192],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, True],
'width_coefficient': 1.0,
'depth_coefficient': 1.0,
'train_size': 192,
'eval_size': 224,
'dropout': 0.2,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVBhWkZRcWNXR3dINmRLP2U9UUI5ZndH/root/content',
'model_name': 'efficientnet_v2_b0_21k_ft1k-a91e14c5.pth'},
'b1': {'num_repeat': [1, 2, 2, 3, 5, 8],
'kernel_size': [3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2],
'expand_ratio': [1, 4, 4, 4, 6, 6],
'in_channel': [32, 16, 32, 48, 96, 112],
'out_channel': [16, 32, 48, 96, 112, 192],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, True],
'width_coefficient': 1.0,
'depth_coefficient': 1.1,
'train_size': 192,
'eval_size': 240,
'dropout': 0.2,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVJnVGV5UndSY2J2amwtP2U9dTBiV1lO/root/content',
'model_name': 'efficientnet_v2_b1_21k_ft1k-58f4fb47.pth'},
'b2': {'num_repeat': [1, 2, 2, 3, 5, 8],
'kernel_size': [3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2],
'expand_ratio': [1, 4, 4, 4, 6, 6],
'in_channel': [32, 16, 32, 48, 96, 112],
'out_channel': [16, 32, 48, 96, 112, 192],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, True],
'width_coefficient': 1.1,
'depth_coefficient': 1.2,
'train_size': 208,
'eval_size': 260,
'dropout': 0.3,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVY4M2NySVFZbU41X0tGP2U9ZERZVmxK/root/content',
'model_name': 'efficientnet_v2_b2_21k_ft1k-db4ac0ee.pth'},
'b3': {'num_repeat': [1, 2, 2, 3, 5, 8],
'kernel_size': [3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2],
'expand_ratio': [1, 4, 4, 4, 6, 6],
'in_channel': [32, 16, 32, 48, 96, 112],
'out_channel': [16, 32, 48, 96, 112, 192],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, True],
'width_coefficient': 1.2,
'depth_coefficient': 1.4,
'train_size': 240,
'eval_size': 300,
'dropout': 0.3,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlnUVpkamdZUzhhaDdtTTZLP2U9anA4VWN2/root/content',
'model_name': 'efficientnet_v2_b3_21k_ft1k-3da5874c.pth'},
's': {'num_repeat': [2, 4, 4, 6, 9, 15],
'kernel_size': [3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2],
'expand_ratio': [1, 4, 4, 4, 6, 6],
'in_channel': [24, 24, 48, 64, 128, 160],
'out_channel': [24, 48, 64, 128, 160, 256],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, True],
'width_coefficient': 1.0,
'depth_coefficient': 1.0,
'train_size': 300,
'eval_size': 384,
'dropout': 0.2,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllbFF5VWJOZzd0cmhBbm8/root/content',
'model_name': 'efficientnet_v2_s_21k_ft1k-dbb43f38.pth'},
'm': {'num_repeat': [3, 5, 5, 7, 14, 18, 5],
'kernel_size': [3, 3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2, 1],
'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
'in_channel': [24, 24, 48, 80, 160, 176, 304],
'out_channel': [24, 48, 80, 160, 176, 304, 512],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, False, True],
'width_coefficient': 1.0,
'depth_coefficient': 1.0,
'train_size': 384,
'eval_size': 480,
'dropout': 0.3,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmllN1ZDazRFb0o1bnlyNUE/root/content',
'model_name': 'efficientnet_v2_m_21k_ft1k-da8e56c0.pth'},
'l': {'num_repeat': [4, 7, 7, 10, 19, 25, 7],
'kernel_size': [3, 3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2, 1],
'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
'in_channel': [32, 32, 64, 96, 192, 224, 384],
'out_channel': [32, 64, 96, 192, 224, 384, 640],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, False, True],
'feature_stages': [1, 2, 4, 6],
'width_coefficient': 1.0,
'depth_coefficient': 1.0,
'train_size': 384,
'eval_size': 480,
'dropout': 0.4,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmcmIyRHEtQTBhUTBhWVE/root/content',
'model_name': 'efficientnet_v2_l_21k_ft1k-08121eee.pth'},
'xl': {'num_repeat': [4, 8, 8, 16, 24, 32, 8],
'kernel_size': [3, 3, 3, 3, 3, 3, 3],
'stride': [1, 2, 2, 2, 1, 2, 1],
'expand_ratio': [1, 4, 4, 4, 6, 6, 6],
'in_channel': [32, 32, 64, 96, 192, 256, 512],
'out_channel': [32, 64, 96, 192, 256, 512, 640],
'se_ratio': [None, None, None, 0.25, 0.25, 0.25, 0.25],
'conv_type': [1, 1, 1, 0, 0, 0, 0],
'is_feature_stage': [False, True, True, False, True, False, True],
'feature_stages': [1, 2, 4, 6],
'width_coefficient': 1.0,
'depth_coefficient': 1.0,
'train_size': 384,
'eval_size': 512,
'dropout': 0.4,
'weight_url': 'https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGlRcHc5VGNjZmlmVXQtRHJLa21taUkxWkE/root/content',
'model_name': 'efficientnet_v2_xl_21k_ft1k-1fcc9744.pth'}}
def __init__(self,
model_name,
in_channels=3,
n_classes=1000,
tf_style_conv=False,
in_spatial_shape=None,
activation='silu',
activation_kwargs=None,
bias=False,
drop_connect_rate=0.2,
dropout_rate=None,
bn_epsilon=1e-3,
bn_momentum=0.01,
pretrained=False,
progress=False,
):
super().__init__()
self.blocks = nn.ModuleList()
self.model_name = model_name
self.cfg = self._models[model_name]
if tf_style_conv and in_spatial_shape is None:
in_spatial_shape = self.cfg['eval_size']
activation_kwargs = {} if activation_kwargs is None else activation_kwargs
dropout_rate = self.cfg['dropout'] if dropout_rate is None else dropout_rate
_input_ch = in_channels
self.feature_block_ids = []
# stem
if tf_style_conv:
self.stem_conv = SamePaddingConv2d(
in_spatial_shape=in_spatial_shape,
in_channels=in_channels,
out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
kernel_size=3,
stride=2,
bias=bias
)
in_spatial_shape = self.stem_conv.out_spatial_shape
else:
self.stem_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
kernel_size=3,
stride=2,
padding=1,
bias=bias
)
self.stem_bn = nn.BatchNorm2d(
num_features=round_filters(self.cfg['in_channel'][0], self.cfg['width_coefficient']),
eps=bn_epsilon,
momentum=bn_momentum)
self.stem_act = get_activation(activation, **activation_kwargs)
drop_connect_rates = self.get_dropconnect_rates(drop_connect_rate)
stages = zip(*[self.cfg[x] for x in
['num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'in_channel', 'out_channel', 'se_ratio',
'conv_type', 'is_feature_stage']])
idx = 0
for stage_args in stages:
(num_repeat, kernel_size, stride, expand_ratio,
in_channels, out_channels, se_ratio, conv_type, is_feature_stage) = stage_args
in_channels = round_filters(
in_channels, self.cfg['width_coefficient'])
out_channels = round_filters(
out_channels, self.cfg['width_coefficient'])
num_repeat = round_repeats(
num_repeat, self.cfg['depth_coefficient'])
conv_block = MBConvBlockV2 if conv_type == 0 else FusedMBConvBlockV2
for _ in range(num_repeat):
se_size = None if se_ratio is None else max(1, int(in_channels * se_ratio))
_b = conv_block(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
expansion_factor=expand_ratio,
act_fn=activation,
act_kwargs=activation_kwargs,
bn_epsilon=bn_epsilon,
bn_momentum=bn_momentum,
se_size=se_size,
drop_connect_rate=drop_connect_rates[idx],
bias=bias,
tf_style_conv=tf_style_conv,
in_spatial_shape=in_spatial_shape
)
self.blocks.append(_b)
idx += 1
if tf_style_conv:
in_spatial_shape = _b.out_spatial_shape
in_channels = out_channels
stride = 1
if is_feature_stage:
self.feature_block_ids.append(idx - 1)
head_conv_out_channels = round_filters(1280, self.cfg['width_coefficient'])
self.head_conv = nn.Conv2d(in_channels=in_channels,
out_channels=head_conv_out_channels,
kernel_size=1,
bias=bias)
self.head_bn = nn.BatchNorm2d(num_features=head_conv_out_channels,
eps=bn_epsilon,
momentum=bn_momentum)
self.head_act = get_activation(activation, **activation_kwargs)
# self.CoordAtt1 = CoordAtt(head_conv_out_channels,head_conv_out_channels)
self.dropout = nn.Dropout(p=dropout_rate)
self.avpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(head_conv_out_channels, n_classes)
if pretrained:
self._load_state(_input_ch, n_classes, progress, tf_style_conv)
def _load_state(self, in_channels, n_classes, progress, tf_style_conv):
state_dict = model_zoo.load_url(self.cfg['weight_url'],
progress=progress,
file_name=self.cfg['model_name'])
strict = True
if not tf_style_conv:
state_dict = OrderedDict(
[(k.replace('.conv.', '.'), v) if '.conv.' in k else (k, v) for k, v in state_dict.items()])
if in_channels != 3:
if tf_style_conv:
state_dict.pop('stem_conv.conv.weight')
else:
state_dict.pop('stem_conv.weight')
strict = False
if n_classes != 1000:
state_dict.pop('fc.weight')
state_dict.pop('fc.bias')
strict = False
self.load_state_dict(state_dict, strict=strict)
print("Model weights loaded successfully.")
def get_dropconnect_rates(self, drop_connect_rate):
nr = self.cfg['num_repeat']
dc = self.cfg['depth_coefficient']
total = sum(round_repeats(nr[i], dc) for i in range(len(nr)))
return [drop_connect_rate * i / total for i in range(total)]
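    # The rate ramps linearly from 0 on the first block towards the configured maximum,
    # e.g. 20 blocks at drop_connect_rate=0.2 give [0.0, 0.01, 0.02, ..., 0.19].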
def get_features(self, x):
x = self.stem_act(self.stem_bn(self.stem_conv(x)))
features = []
feat_idx = 0
for block_idx, block in enumerate(self.blocks):
x = block(x)
            # guard feat_idx so the loop is safe once all feature stages are collected
            if feat_idx < len(self.feature_block_ids) and block_idx == self.feature_block_ids[feat_idx]:
                features.append(x)
                feat_idx += 1
return features
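    # get_features returns the intermediate maps of the stages flagged in
    # cfg['is_feature_stage'] (useful for FPN-style decoders). Sketch, assuming a
    # constructed `model`:
    #   feats = model.get_features(torch.randn(1, 3, 224, 224))
    #   [f.shape for f in feats]  # one entry per feature stage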
def forward(self, x):
x = self.stem_act(self.stem_bn(self.stem_conv(x)))
for block in self.blocks:
x = block(x)
x = self.head_act(self.head_bn(self.head_conv(x)))
# x = self.CoordAtt1(x)
x = self.dropout(torch.flatten(self.avpool(x), 1))
x = self.fc(x)
return x
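

if __name__ == "__main__":
    # Minimal smoke test (sketch): build an untrained b0 and push a dummy batch through.
    # pretrained=False avoids the weight download; n_classes=4 is illustrative (e.g. a
    # four-class brain-tumor dataset); eval() disables dropout and drop-connect.
    model = EfficientNetV2('b0', n_classes=4, pretrained=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 4])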