# brain-tumor_image_classific.../efn.py
import torch
from timm.models.layers import DropPath
from torch.nn.functional import cross_entropy, dropout, one_hot, softmax


def init_weight(model):
    # Kaiming-style (fan-out) normal init for convolutions, uniform init for linear layers.
    import math
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out = fan_out // m.groups
            torch.nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        if isinstance(m, torch.nn.Linear):
            init_range = 1.0 / math.sqrt(m.weight.size()[0])
            torch.nn.init.uniform_(m.weight, -init_range, init_range)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)


class Conv(torch.nn.Module):
    def __init__(self, in_ch, out_ch, activation, k=1, s=1, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, (k - 1) // 2, 1, g, bias=False)
        self.norm = torch.nn.BatchNorm2d(out_ch, 0.001, 0.01)
        self.relu = activation

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))


class SE(torch.nn.Module):
    def __init__(self, ch, r):
        super().__init__()
        self.se = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
                                      torch.nn.Conv2d(ch, ch // (4 * r), 1),
                                      torch.nn.SiLU(),
                                      torch.nn.Conv2d(ch // (4 * r), ch, 1),
                                      torch.nn.Sigmoid())

    def forward(self, x):
        return x * self.se(x)


class Residual(torch.nn.Module):
    """
    Inverted residual block (MBConv), with an optional fused variant.
    [https://arxiv.org/pdf/1801.04381.pdf]
    """

    def __init__(self, in_ch, out_ch, s, r, dp_rate=0, fused=True):
        super().__init__()
        identity = torch.nn.Identity()
        self.add = s == 1 and in_ch == out_ch

        if fused:
            # Fused-MBConv: 3x3 expansion conv followed by a 1x1 projection (skipped when r == 1).
            features = [Conv(in_ch, r * in_ch, activation=torch.nn.SiLU(), k=3, s=s),
                        Conv(r * in_ch, out_ch, identity) if r != 1 else identity,
                        DropPath(dp_rate) if self.add else identity]
        else:
            # MBConv: 1x1 expansion (skipped when r == 1), 3x3 depth-wise conv, SE, 1x1 projection.
            features = [Conv(in_ch, r * in_ch, torch.nn.SiLU()) if r != 1 else identity,
                        Conv(r * in_ch, r * in_ch, torch.nn.SiLU(), 3, s, r * in_ch),
                        SE(r * in_ch, r), Conv(r * in_ch, out_ch, identity),
                        DropPath(dp_rate) if self.add else identity]
        self.res = torch.nn.Sequential(*features)

    def forward(self, x):
        return x + self.res(x) if self.add else self.res(x)


class EfficientNet(torch.nn.Module):
    """
    efficientnet-v2-s :
        num_dep = [2, 4, 4, 6, 9, 15, 0]
        filters = [24, 48, 64, 128, 160, 256, 256, 1280]
    efficientnet-v2-m :
        num_dep = [3, 5, 5, 7, 14, 18, 5]
        filters = [24, 48, 80, 160, 176, 304, 512, 1280]
    efficientnet-v2-l :
        num_dep = [4, 7, 7, 10, 19, 25, 7]
        filters = [32, 64, 96, 192, 224, 384, 640, 1280]
    """

    def __init__(self, drop_rate=0, num_class=4):
        super().__init__()
        num_dep = [2, 4, 4, 6, 9, 15, 0]
        filters = [24, 48, 64, 128, 160, 256, 256, 1280]
        dp_index = 0
        dp_rates = [x.item() for x in torch.linspace(0, 0.2, sum(num_dep))]
        self.p1 = []
        self.p2 = []
        self.p3 = []
        self.p4 = []
        self.p5 = []
        # p1/2
        for i in range(num_dep[0]):
            if i == 0:
                self.p1.append(Conv(3, filters[0], torch.nn.SiLU(), 3, 2))
                self.p1.append(Residual(filters[0], filters[0], 1, 1, dp_rates[dp_index]))
            else:
                self.p1.append(Residual(filters[0], filters[0], 1, 1, dp_rates[dp_index]))
            dp_index += 1
        # p2/4
        for i in range(num_dep[1]):
            if i == 0:
                self.p2.append(Residual(filters[0], filters[1], 2, 4, dp_rates[dp_index]))
            else:
                self.p2.append(Residual(filters[1], filters[1], 1, 4, dp_rates[dp_index]))
            dp_index += 1
        # p3/8
        for i in range(num_dep[2]):
            if i == 0:
                self.p3.append(Residual(filters[1], filters[2], 2, 4, dp_rates[dp_index]))
            else:
                self.p3.append(Residual(filters[2], filters[2], 1, 4, dp_rates[dp_index]))
            dp_index += 1
        # p4/16
        for i in range(num_dep[3]):
            if i == 0:
                self.p4.append(Residual(filters[2], filters[3], 2, 4, dp_rates[dp_index], False))
            else:
                self.p4.append(Residual(filters[3], filters[3], 1, 4, dp_rates[dp_index], False))
            dp_index += 1
        for i in range(num_dep[4]):
            if i == 0:
                self.p4.append(Residual(filters[3], filters[4], 1, 6, dp_rates[dp_index], False))
            else:
                self.p4.append(Residual(filters[4], filters[4], 1, 6, dp_rates[dp_index], False))
            dp_index += 1
        # p5/32
        for i in range(num_dep[5]):
            if i == 0:
                self.p5.append(Residual(filters[4], filters[5], 2, 6, dp_rates[dp_index], False))
            else:
                self.p5.append(Residual(filters[5], filters[5], 1, 6, dp_rates[dp_index], False))
            dp_index += 1
        for i in range(num_dep[6]):
            if i == 0:
                self.p5.append(Residual(filters[5], filters[6], 2, 6, dp_rates[dp_index], False))
            else:
                self.p5.append(Residual(filters[6], filters[6], 1, 6, dp_rates[dp_index], False))
            dp_index += 1
        self.p1 = torch.nn.Sequential(*self.p1)
        self.p2 = torch.nn.Sequential(*self.p2)
        self.p3 = torch.nn.Sequential(*self.p3)
        self.p4 = torch.nn.Sequential(*self.p4)
        self.p5 = torch.nn.Sequential(*self.p5)
        self.fc1 = torch.nn.Sequential(Conv(filters[6], filters[7], torch.nn.SiLU()),
                                       torch.nn.AdaptiveAvgPool2d(1),
                                       torch.nn.Flatten())
        self.fc2 = torch.nn.Linear(filters[7], num_class)
        self.drop_rate = drop_rate
        init_weight(self)

    def forward(self, x):
        x = self.p1(x)
        x = self.p2(x)
        x = self.p3(x)
        x = self.p4(x)
        x = self.p5(x)
        x = self.fc1(x)
        if self.drop_rate > 0:
            x = dropout(x, self.drop_rate, self.training)
        return self.fc2(x)

    def export(self):
        # Swap torch.nn.SiLU for timm's Swish (x * sigmoid(x) built from primitive ops),
        # which is friendlier to export paths that lack a native SiLU operator.
        from timm.models.layers import Swish
        for m in self.modules():
            if type(m) is Conv and hasattr(m, 'relu'):
                if isinstance(m.relu, torch.nn.SiLU):
                    m.relu = Swish()
            if type(m) is SE:
                if isinstance(m.se[2], torch.nn.SiLU):
                    m.se[2] = Swish()
        return self
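
# Example (an illustrative assumption, not code from the original repository): the class
# above hard-codes the efficientnet-v2-s configuration from its docstring; building the
# -m or -l variants would mean swapping the `num_dep` and `filters` lists in __init__.
# A forward pass on a 224x224 RGB image returns one logit per class:
#
#     net = EfficientNet(drop_rate=0.2, num_class=4)
#     logits = net(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 4])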


class EMA:
    """
    Exponential moving average of model weights.
    """

    def __init__(self, model, decay=0.9999):
        super().__init__()
        import copy
        self.decay = decay
        self.model = copy.deepcopy(model)
        self.model.eval()

    def update_fn(self, model, fn):
        with torch.no_grad():
            e_std = self.model.state_dict().values()
            m_std = model.module.state_dict().values()  # expects a DataParallel/DDP-wrapped model
            for e, m in zip(e_std, m_std):
                e.copy_(fn(e, m))

    def update(self, model):
        self.update_fn(model, fn=lambda e, m: self.decay * e + (1. - self.decay) * m)


class StepLR:
    """
    Linear warmup for the first epochs, then the base learning rate decayed by
    `decay_rate` every `decay_epochs` epochs.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer

        for param_group in self.optimizer.param_groups:
            param_group.setdefault('initial_lr', param_group['lr'])

        self.base_values = [param_group['initial_lr'] for param_group in self.optimizer.param_groups]
        self.update_groups(self.base_values)

        self.decay_rate = 0.97
        self.decay_epochs = 2.4
        self.warmup_epochs = 3.0
        self.warmup_lr_init = 1e-6

        self.warmup_steps = [(v - self.warmup_lr_init) / self.warmup_epochs for v in self.base_values]
        self.update_groups(self.warmup_lr_init)

    def step(self, epoch: int) -> None:
        if epoch < self.warmup_epochs:
            values = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
        else:
            values = [v * (self.decay_rate ** (epoch // self.decay_epochs)) for v in self.base_values]
        if values is not None:
            self.update_groups(values)

    def update_groups(self, values):
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group['lr'] = value


class RMSprop(torch.optim.Optimizer):
    """
    RMSprop with TensorFlow-style behaviour: epsilon is added inside the square root
    and the squared-gradient accumulator is initialized to one.
    """

    def __init__(self, params,
                 lr=1e-2, alpha=0.9, eps=1e-3, weight_decay=0, momentum=0.9,
                 centered=False, decoupled_decay=False, lr_in_momentum=True):
        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum,
                        centered=centered, decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for param_group in self.param_groups:
            param_group.setdefault('momentum', 0)
            param_group.setdefault('centered', False)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        for param_group in self.param_groups:
            for param in param_group['params']:
                if param.grad is None:
                    continue
                grad = param.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Optimizer does not support sparse gradients')
                state = self.state[param]

                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(param.data)
                    if param_group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(param.data)
                    if param_group['centered']:
                        state['grad_avg'] = torch.zeros_like(param.data)

                square_avg = state['square_avg']
                one_minus_alpha = 1. - param_group['alpha']
                state['step'] += 1

                if param_group['weight_decay'] != 0:
                    if 'decoupled_decay' in param_group and param_group['decoupled_decay']:
                        param.data.add_(param.data, alpha=-param_group['weight_decay'])
                    else:
                        grad = grad.add(param.data, alpha=param_group['weight_decay'])

                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)

                if param_group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(param_group['eps']).sqrt_()
                else:
                    avg = square_avg.add(param_group['eps']).sqrt_()

                if param_group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    if 'lr_in_momentum' in param_group and param_group['lr_in_momentum']:
                        buf.mul_(param_group['momentum']).addcdiv_(grad, avg, value=param_group['lr'])
                        param.data.add_(-buf)
                    else:
                        buf.mul_(param_group['momentum']).addcdiv_(grad, avg)
                        param.data.add_(buf, alpha=-param_group['lr'])
                else:
                    param.data.addcdiv_(grad, avg, value=-param_group['lr'])
        return loss


class PolyLoss(torch.nn.Module):
    """
    PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions
    Poly-1 formulation: L = CE + epsilon * (1 - pt), where pt is the probability
    assigned to the target class.
    """

    def __init__(self, epsilon=2.0):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, outputs, targets):
        ce = cross_entropy(outputs, targets)
        pt = one_hot(targets, outputs.size()[1]) * softmax(outputs, 1)
        return (ce + self.epsilon * (1.0 - pt.sum(dim=1))).mean()


class CrossEntropyLoss(torch.nn.Module):
    """
    NLL Loss with label smoothing.
    """

    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        prob = self.softmax(x)
        mean = -prob.mean(dim=-1)
        nll_loss = -prob.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        return ((1. - self.epsilon) * nll_loss + self.epsilon * mean).mean()
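

# Minimal usage sketch (an assumption, not the repository's training script, which is
# not part of this file): it only shows how EfficientNet, PolyLoss, RMSprop, StepLR and
# EMA are meant to fit together; batch size, learning rate and the 224x224 input size
# are placeholder values.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = EfficientNet(drop_rate=0.2, num_class=4).to(device)
    model = torch.nn.DataParallel(net)  # EMA.update() reads model.module.state_dict()
    ema = EMA(net)
    criterion = PolyLoss(epsilon=2.0)
    optimizer = RMSprop(net.parameters(), lr=1e-2, weight_decay=1e-5)
    scheduler = StepLR(optimizer)

    images = torch.randn(2, 3, 224, 224, device=device)  # dummy batch
    targets = torch.randint(0, 4, (2,), device=device)   # dummy labels for 4 classes

    model.train()
    for epoch in range(2):
        scheduler.step(epoch)
        outputs = model(images)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)
        print(f'epoch {epoch}: lr {optimizer.param_groups[0]["lr"]:.6f} loss {loss.item():.4f}')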