# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
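"""
Backbone components for the bottom-up-attention (BUA) detector: Caffe-style
ResNet stems (v1 and a pre-activation v2 variant) and a config-driven backbone
builder registered with detectron2's BACKBONE_REGISTRY.
"""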
import fvcore.nn.weight_init as weight_init
from torch import nn
import torch.nn.functional as F
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm, BatchNorm2d
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, make_stage
from detectron2.modeling.backbone.resnet import BottleneckBlock, DeformBottleneckBlock, ResNetBlockBase
from .layers.wrappers import Conv2dv2
__all__ = ["BUABasicStem", "BUABasicStemv2", "build_bua_resnet_backbone"]


class BUABasicStem(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
        """
        super().__init__()
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu_(x)
        # Caffe-style pooling: padding=0 with ceil_mode=True (detectron2's
        # default stem uses padding=1 with floor mode instead).
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
        return x

    @property
    def out_channels(self):
        return self.conv1.out_channels

    @property
    def stride(self):
        return 4  # = stride 2 conv -> stride 2 max pool
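
# Shape sketch for the stem (illustrative; the 3x224x224 input is an
# assumption, not something fixed by this file):
#   conv1, k=7 s=2 p=3:              224 -> 112
#   max_pool2d, k=3 s=2 p=0, ceil:   112 -> 56
# Total downsampling is 4x, matching the `stride` property above.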


class BUABasicStemv2(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
        """
        super().__init__()
        # The v2 stem normalizes the raw input before the first convolution.
        self.norm = BatchNorm2d(in_channels, eps=2e-5)
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=BatchNorm2d(out_channels, eps=2e-5),
        )
        # weight_init.c2_msra_fill(self.norm)
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        x = self.norm(x)
        x = self.conv1(x)
        x = F.relu_(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
        return x

    @property
    def out_channels(self):
        return self.conv1.out_channels

    @property
    def stride(self):
        return 4  # = stride 2 conv -> stride 2 max pool
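
# Note: BUABasicStemv2 accepts a `norm` argument for interface parity with
# BUABasicStem but never uses it — it hard-codes BatchNorm2d(eps=2e-5),
# presumably to match the original Caffe parameters this repo converts from.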


@BACKBONE_REGISTRY.register()
def build_bua_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    if cfg.MODEL.BUA.RESNET_VERSION == 2:
        stem = BUABasicStemv2(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        )
    else:
        stem = BUABasicStem(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
            norm=norm,
        )
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT

    if freeze_at >= 1:
        for p in stem.parameters():
            p.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    # fmt: off
    out_features        = cfg.MODEL.RESNETS.OUT_FEATURES
    depth               = cfg.MODEL.RESNETS.DEPTH
    num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    stages = []

    # Avoid creating variables without gradients
    # It consumes extra memory and may cause allreduce to fail
    out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
    max_stage_idx = max(out_stage_idx)
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = res5_dilation if stage_idx == 5 else 1
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        stage_kargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "first_stride": first_stride,
            "in_channels": in_channels,
            "bottleneck_channels": bottleneck_channels,
            "out_channels": out_channels,
            "num_groups": num_groups,
            "norm": norm,
            "stride_in_1x1": stride_in_1x1,
            "dilation": dilation,
        }
        if deform_on_per_stage[idx]:
            stage_kargs["block_class"] = DeformBottleneckBlock
            stage_kargs["deform_modulated"] = deform_modulated
            stage_kargs["deform_num_groups"] = deform_num_groups
        else:
            stage_kargs["block_class"] = (
                BottleneckBlock if cfg.MODEL.BUA.RESNET_VERSION == 1 else BottleneckBlockv2
            )
        blocks = make_stage(**stage_kargs)
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)
    return ResNet(stem, stages, out_features=out_features)
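
# Usage sketch (hypothetical config values; assumes this repo's config setup
# adds the MODEL.BUA.* keys on top of detectron2's defaults):
#
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_backbone
#
#   cfg = get_cfg()                          # plus this repo's BUA defaults
#   cfg.MODEL.BACKBONE.NAME = "build_bua_resnet_backbone"
#   backbone = build_backbone(cfg)           # dispatches via BACKBONE_REGISTRY
#   features = backbone(images)              # dict keyed by OUT_FEATURES, e.g. {"res4": ...}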


class BottleneckBlockv2(ResNetBlockBase):
    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns a `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
            stride_in_1x1 (bool): when stride==2, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2dv2(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=None,
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        # Pre-activation ordering: `Conv2dv2` (from .layers.wrappers) applies
        # its `norm`/`activation` *before* the convolution, which is why each
        # BN below is sized to the conv's *input* channels. conv1 takes
        # norm=None because the shared BN + ReLU in `forward` already
        # pre-activates its input.
        self.conv1 = Conv2dv2(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=None,
        )

        self.conv2 = Conv2dv2(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
            activation=F.relu_,
        )

        self.conv3 = Conv2dv2(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
            activation=F.relu_,
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        self.norm = BatchNorm2d(in_channels, eps=2e-5)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."
        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        # Shared pre-activation: BN + ReLU on the block input. The projection
        # shortcut (when present) branches off this pre-activated tensor; the
        # identity shortcut uses the raw input `x`.
        x_2 = self.norm(x)
        x_2 = F.relu_(x_2)
        out = self.conv1(x_2)
        # out = F.relu_(out)
        out = self.conv2(out)
        # out = F.relu_(out)
        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x_2)
        else:
            shortcut = x
        out += shortcut
        # out = F.relu_(out)
        return out
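

if __name__ == "__main__":
    # Minimal smoke test for the v1 stem (a sketch; assumes torch, fvcore, and
    # detectron2 are installed — the 224x224 input size is an arbitrary choice).
    import torch

    stem = BUABasicStem(in_channels=3, out_channels=64, norm="BN")
    x = torch.randn(2, 3, 224, 224)
    y = stem(x)
    # stride-2 conv then stride-2 ceil-mode pool: 224 -> 112 -> 56
    assert y.shape == (2, 64, 56, 56)
    print("stem output:", tuple(y.shape))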