342 lines
13 KiB
Python
342 lines
13 KiB
Python
import torch
|
|
from timm.models.layers import DropPath
|
|
from torch.nn.functional import cross_entropy, dropout, one_hot, softmax
|
|
|
|
|
|
def init_weight(model):
    """Re-initialise every Conv2d and Linear layer found in ``model``.

    Conv2d weights are drawn from a zero-mean normal distribution whose
    standard deviation is ``sqrt(2 / fan_out)``, where ``fan_out`` is
    ``kernel_h * kernel_w * out_channels // groups``.  Linear weights are
    drawn uniformly from ``[-1/sqrt(out_features), 1/sqrt(out_features)]``.
    Biases of both layer kinds, when present, are set to zero.  All other
    module types are left untouched.  The model is modified in place.
    """
    import math

    for layer in model.modules():
        if isinstance(layer, torch.nn.Conv2d):
            kernel_h, kernel_w = layer.kernel_size[0], layer.kernel_size[1]
            fan_out = kernel_h * kernel_w * layer.out_channels
            fan_out = fan_out // layer.groups
            torch.nn.init.normal_(layer.weight, 0, math.sqrt(2.0 / fan_out))
        elif isinstance(layer, torch.nn.Linear):
            # weight is (out_features, in_features); the bound uses dim 0.
            bound = 1.0 / math.sqrt(layer.weight.size()[0])
            torch.nn.init.uniform_(layer.weight, -bound, bound)
        else:
            continue

        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)
|
|
|
|
|
|
class Conv(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and a caller-supplied activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        activation: module applied after normalisation (stored as ``relu``).
        k: square kernel size; padding is ``(k - 1) // 2``.
        s: stride.
        g: number of convolution groups.
    """

    def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=in_ch,
                                    out_channels=out_ch,
                                    kernel_size=k,
                                    stride=s,
                                    padding=(k - 1) // 2,
                                    dilation=1,
                                    groups=g,
                                    bias=False)
        self.norm = torch.nn.BatchNorm2d(num_features=out_ch, eps=0.001, momentum=0.01)
        # The attribute is called ``relu`` regardless of which activation is passed in.
        self.relu = activation

    def forward(self, x):
        """Apply convolution, batch normalisation and the activation, in that order."""
        out = self.conv(x)
        out = self.norm(out)
        return self.relu(out)
|
|
|
|
|
|
class SE(torch.nn.Module):
    """Channel gating: scales the input by a per-channel sigmoid weight.

    The gate is computed from a globally average-pooled copy of the input,
    passed through a 1x1 convolution down to ``ch // (4 * r)`` channels, SiLU,
    a 1x1 convolution back to ``ch`` channels, and a sigmoid.

    Args:
        ch: number of input (and output) channels.
        r: divisor factor; the bottleneck width is ``ch // (4 * r)``.
    """

    def __init__(self, ch, r):
        super().__init__()
        squeezed = ch // (4 * r)
        pool = torch.nn.AdaptiveAvgPool2d(1)
        reduce = torch.nn.Conv2d(ch, squeezed, 1)
        expand = torch.nn.Conv2d(squeezed, ch, 1)
        # Layer positions matter: other code indexes the activation as se[2].
        self.se = torch.nn.Sequential(pool,
                                      reduce,
                                      torch.nn.SiLU(),
                                      expand,
                                      torch.nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` multiplied by its channel gate."""
        gate = self.se(x)
        return x * gate
|
|
|
|
|
|
class Residual(torch.nn.Module):
    """
    Inverted-residual block in two flavours.

    [https://arxiv.org/pdf/1801.04381.pdf]

    fused=True : 3x3 Conv expanding to ``r * in_ch`` channels, then (only when
                 ``r != 1``) a 1x1 Conv projecting to ``out_ch``.
    fused=False: optional 1x1 expansion Conv (when ``r != 1``), a 3x3 Conv with
                 one group per channel, an SE gate, and a 1x1 projection Conv.

    A skip connection (with DropPath on the branch) is used only when the
    stride is 1 and ``in_ch == out_ch``.
    """

    def __init__(self, in_ch, out_ch, s, r, dp_rate=0, fused=True):
        super().__init__()
        self.add = s == 1 and in_ch == out_ch

        # A single Identity instance doubles as "no activation" and "no layer".
        skip = torch.nn.Identity()
        mid_ch = r * in_ch
        expands = r != 1

        if fused:
            layers = [Conv(in_ch, mid_ch, activation=torch.nn.SiLU(), k=3, s=s)]
            layers.append(Conv(mid_ch, out_ch, skip) if expands else skip)
        else:
            layers = [Conv(in_ch, mid_ch, torch.nn.SiLU()) if expands else skip]
            layers.append(Conv(mid_ch, mid_ch, torch.nn.SiLU(), 3, s, mid_ch))
            layers.append(SE(mid_ch, r))
            layers.append(Conv(mid_ch, out_ch, skip))

        # Stochastic depth only makes sense when there is a skip path to fall back on.
        layers.append(DropPath(dp_rate) if self.add else skip)

        self.res = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run the branch and add the input back when the skip connection is active."""
        branch = self.res(x)
        if self.add:
            return x + branch
        return branch
|
|
|
|
|
|
class EfficientNet(torch.nn.Module):
    """
    EfficientNet-V2 classifier; the constructor hard-codes the "s" configuration.

    efficientnet-v2-s :
                        num_dep = [2, 4, 4, 6, 9, 15, 0]
                        filters = [24, 48, 64, 128, 160, 256, 256, 1280]
    efficientnet-v2-m :
                        num_dep = [3, 5, 5, 7, 14, 18, 5]
                        filters = [24, 48, 80, 160, 176, 304, 512, 1280]
    efficientnet-v2-l :
                        num_dep = [4, 7, 7, 10, 19, 25, 7]
                        filters = [32, 64, 96, 192, 224, 384, 640, 1280]

    Args:
        drop_rate: dropout probability applied to the pooled features before
            the final linear layer (skipped when not greater than zero).
        num_class: number of output logits.
    """

    def __init__(self, drop_rate=0, num_class=4):
        super().__init__()
        depths = [2, 4, 4, 6, 9, 15, 0]
        widths = [24, 48, 64, 128, 160, 256, 256, 1280]

        # One drop-path rate per residual block, rising linearly from 0 to 0.2
        # and consumed in construction order.
        rates = iter([rate.item() for rate in torch.linspace(0, 0.2, sum(depths))])

        def stage(depth, in_ch, out_ch, stride, expand, fused=True):
            """Build ``depth`` Residual blocks; only the first changes channels/stride."""
            blocks = []
            for index in range(depth):
                if index == 0:
                    blocks.append(Residual(in_ch, out_ch, stride, expand, next(rates), fused))
                else:
                    blocks.append(Residual(out_ch, out_ch, 1, expand, next(rates), fused))
            return blocks

        # p1/2 -- the stem convolution is only emitted together with the first block
        p1 = [Conv(3, widths[0], torch.nn.SiLU(), 3, 2)] if depths[0] > 0 else []
        p1 += stage(depths[0], widths[0], widths[0], 1, 1)
        # p2/4
        p2 = stage(depths[1], widths[0], widths[1], 2, 4)
        # p3/8
        p3 = stage(depths[2], widths[1], widths[2], 2, 4)
        # p4/16
        p4 = stage(depths[3], widths[2], widths[3], 2, 4, False)
        p4 += stage(depths[4], widths[3], widths[4], 1, 6, False)
        # p5/32
        p5 = stage(depths[5], widths[4], widths[5], 2, 6, False)
        # NOTE(review): this last stage is empty for the "s" config; the stride
        # of 2 here should be checked against the reference m/l configs before
        # enabling it.
        p5 += stage(depths[6], widths[5], widths[6], 2, 6, False)

        self.p1 = torch.nn.Sequential(*p1)
        self.p2 = torch.nn.Sequential(*p2)
        self.p3 = torch.nn.Sequential(*p3)
        self.p4 = torch.nn.Sequential(*p4)
        self.p5 = torch.nn.Sequential(*p5)

        self.fc1 = torch.nn.Sequential(Conv(widths[6], widths[7], torch.nn.SiLU()),
                                       torch.nn.AdaptiveAvgPool2d(1),
                                       torch.nn.Flatten())
        self.fc2 = torch.nn.Linear(widths[7], num_class)

        self.drop_rate = drop_rate

        init_weight(self)

    def forward(self, x):
        """Run the five stages, pool to a feature vector, and return class logits."""
        for stage in (self.p1, self.p2, self.p3, self.p4, self.p5):
            x = stage(x)

        x = self.fc1(x)
        if self.drop_rate > 0:
            x = dropout(x, self.drop_rate, self.training)
        return self.fc2(x)

    def export(self):
        """Swap every torch SiLU inside Conv and SE blocks for timm's Swish; returns self."""
        from timm.models.layers import Swish

        for module in self.modules():
            if type(module) is Conv and hasattr(module, 'relu'):
                if isinstance(module.relu, torch.nn.SiLU):
                    module.relu = Swish()
            if type(module) is SE:
                # Index 2 is the activation slot inside the SE Sequential.
                if isinstance(module.se[2], torch.nn.SiLU):
                    module.se[2] = Swish()
        return self
|
|
|
|
|
|
class EMA:
    """Keeps an exponential moving average copy of a model's state.

    The averaged copy lives in ``self.model`` (a deep copy put into eval mode)
    and is refreshed by calling :meth:`update` with the live model.

    Args:
        model: module to copy and track.
        decay: weight given to the existing averaged value on each update.
    """

    def __init__(self, model, decay=0.9999):
        import copy

        self.decay = decay
        self.model = copy.deepcopy(model)

        self.model.eval()

    def update_fn(self, model, fn):
        """Overwrite each averaged state entry with ``fn(averaged, live)``.

        ``model`` may be a bare module or a wrapper exposing the real module as
        ``.module`` (e.g. DataParallel / DistributedDataParallel).  Entries are
        paired by position in the two state dicts, so both must have the same
        layout.
        """
        # Unwrap parallel wrappers when present; bare modules are used as-is
        # instead of failing on a missing ``module`` attribute.
        source = getattr(model, 'module', model)
        with torch.no_grad():
            averaged = self.model.state_dict().values()
            live = source.state_dict().values()
            for e, m in zip(averaged, live):
                e.copy_(fn(e, m))

    def update(self, model):
        """Blend the live model into the average: ``decay * ema + (1 - decay) * live``."""
        self.update_fn(model, fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
|
|
|
|
|
class StepLR:
    """Linear warm-up followed by staircase exponential decay of the learning rate.

    For ``epoch < warmup_epochs`` the rate rises linearly from
    ``warmup_lr_init`` towards each group's base rate.  Afterwards it is
    ``base * decay_rate ** (epoch // decay_epochs)``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        decay_rate: multiplicative factor applied once per decay period.
        decay_epochs: length of one decay period in epochs (may be fractional).
        warmup_epochs: number of warm-up epochs; 0 disables warm-up.
        warmup_lr_init: learning rate at epoch 0 of the warm-up.
    """

    def __init__(self, optimizer, *,
                 decay_rate=0.97, decay_epochs=2.4, warmup_epochs=3.0, warmup_lr_init=1e-6):
        self.optimizer = optimizer

        # Remember each group's starting rate so the schedule is always
        # computed from it, not from the currently active rate.
        for param_group in self.optimizer.param_groups:
            param_group.setdefault('initial_lr', param_group['lr'])

        self.base_values = [param_group['initial_lr'] for param_group in self.optimizer.param_groups]

        self.decay_rate = decay_rate
        self.decay_epochs = decay_epochs
        self.warmup_epochs = warmup_epochs
        self.warmup_lr_init = warmup_lr_init

        if self.warmup_epochs > 0:
            # Per-epoch increment that reaches the base rate after the warm-up.
            self.warmup_steps = [(v - self.warmup_lr_init) / self.warmup_epochs for v in self.base_values]
            self.update_groups(self.warmup_lr_init)
        else:
            # No warm-up: start directly at the base rates.
            self.warmup_steps = [0.0 for _ in self.base_values]
            self.update_groups(self.base_values)

    def step(self, epoch: int) -> None:
        """Set every group's learning rate to the scheduled value for ``epoch``."""
        if epoch < self.warmup_epochs:
            values = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
        else:
            values = [v * (self.decay_rate ** (epoch // self.decay_epochs)) for v in self.base_values]
        self.update_groups(values)

    def update_groups(self, values):
        """Write ``values`` into the groups' ``lr``; a scalar is applied to every group."""
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group['lr'] = value
|
|
|
|
|
|
class RMSprop(torch.optim.Optimizer):
    """RMSprop with momentum, optional centering and optional decoupled weight decay.

    Differences from a textbook RMSprop that are visible in this implementation:
    the running squared-gradient average starts at ones, and ``eps`` is added
    before the square root rather than after it.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        alpha: smoothing constant for the running averages.
        eps: term added under the square root for numerical stability.
        weight_decay: weight-decay coefficient (0 disables it).
        momentum: momentum factor (0 disables the momentum buffer).
        centered: if True, normalise by an estimate of the gradient variance.
        decoupled_decay: if True, decay the weights directly instead of adding
            the decay term to the gradient.
        lr_in_momentum: if True, the learning rate scales the update before it
            is accumulated into the momentum buffer; otherwise it scales the
            buffer when it is applied.
    """

    def __init__(self, params,
                 lr=1e-2, alpha=0.9, eps=1e-3, weight_decay=0, momentum=0.9,
                 centered=False, decoupled_decay=False, lr_in_momentum=True):

        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum,
                        centered=centered, decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore state and fill in options that the restored groups may lack."""
        super().__setstate__(state)
        for param_group in self.param_groups:
            param_group.setdefault('momentum', 0)
            param_group.setdefault('centered', False)

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; its result is returned by this method.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for param_group in self.param_groups:
            lr = param_group['lr']
            eps = param_group['eps']
            momentum = param_group['momentum']
            weight_decay = param_group['weight_decay']
            centered = param_group['centered']
            one_minus_alpha = 1. - param_group['alpha']

            for param in param_group['params']:
                if param.grad is None:
                    continue
                grad = param.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Optimizer does not support sparse gradients')

                state = self.state[param]
                if len(state) == 0:
                    # Lazy per-parameter state creation on the first step.
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(param.data)
                    if momentum > 0:
                        state['momentum_buffer'] = torch.zeros_like(param.data)
                    if centered:
                        state['grad_avg'] = torch.zeros_like(param.data)

                square_avg = state['square_avg']
                state['step'] += 1

                if weight_decay != 0:
                    if param_group.get('decoupled_decay', False):
                        # Decay the weights themselves; the gradient is untouched.
                        param.data.add_(param.data, alpha=-weight_decay)
                    else:
                        # Classic L2 penalty folded into the gradient (out of place,
                        # so param.grad is not modified).
                        grad = grad.add(param.data, alpha=weight_decay)

                # square_avg <- square_avg + (1 - alpha) * (grad^2 - square_avg)
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)

                if centered:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(eps).sqrt_()
                else:
                    avg = square_avg.add(eps).sqrt_()

                if momentum > 0:
                    buf = state['momentum_buffer']
                    if param_group.get('lr_in_momentum', False):
                        buf.mul_(momentum).addcdiv_(grad, avg, value=lr)
                        param.data.sub_(buf)
                    else:
                        buf.mul_(momentum).addcdiv_(grad, avg)
                        # Keyword ``alpha`` replaces the deprecated
                        # add_(scalar, tensor) positional overload.
                        param.data.add_(buf, alpha=-lr)
                else:
                    param.data.addcdiv_(grad, avg, value=-lr)

        return loss
|
|
|
|
|
|
class PolyLoss(torch.nn.Module):
    """
    PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions

    Computes ``cross_entropy + epsilon * (1 - p_target)`` averaged over the
    batch, where ``p_target`` is the softmax probability of the true class.

    Args:
        epsilon: weight of the ``(1 - p_target)`` term.
    """

    def __init__(self, epsilon=2.0):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, outputs, targets):
        """Return the scalar loss for logits ``outputs`` and class-index ``targets``."""
        num_classes = outputs.size()[1]
        probabilities = softmax(outputs, 1)
        # Masking with the one-hot targets keeps only the true-class probability.
        target_prob = (one_hot(targets, num_classes) * probabilities).sum(dim=1)

        # cross_entropy already averages over the batch; it broadcasts against
        # the per-sample term before the final mean.
        ce = cross_entropy(outputs, targets)
        poly = ce + self.epsilon * (1.0 - target_prob)
        return poly.mean()
|
|
|
|
|
|
class CrossEntropyLoss(torch.nn.Module):
    """
    NLL Loss with label smoothing.

    The per-sample loss is ``(1 - epsilon) * nll + epsilon * uniform``, where
    ``nll`` is the negative log-probability of the target class and
    ``uniform`` is the mean negative log-probability over all classes.  The
    result is averaged over the batch.

    Args:
        epsilon: smoothing weight given to the uniform term.
    """

    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        """Return the scalar smoothed loss for logits ``x`` and class-index ``target``."""
        log_prob = self.softmax(x)

        # Pick each row's log-probability at its target index.
        picked = log_prob.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        target_term = -picked
        uniform_term = -log_prob.mean(dim=-1)

        smoothed = (1. - self.epsilon) * target_term + self.epsilon * uniform_term
        return smoothed.mean()
|