# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
from torch import nn
import torch.nn.functional as F

from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm, BatchNorm2d
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, make_stage
from detectron2.modeling.backbone.resnet import (
    BottleneckBlock,
    DeformBottleneckBlock,
    ResNetBlockBase,
)

from .layers.wrappers import Conv2dv2

__all__ = ["BUABasicStem", "BUABasicStemv2", "build_bua_resnet_backbone"]
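
# NOTE: the v2 ("pre-activation") modules below rely on `Conv2dv2` applying
# its `norm` and `activation` to the *input* before the convolution. A minimal
# sketch of the assumed wrapper (see .layers.wrappers for the real one):
#
#     class Conv2dv2(torch.nn.Conv2d):
#         def forward(self, x):
#             if self.norm is not None:
#                 x = self.norm(x)           # normalize the input ...
#             if self.activation is not None:
#                 x = self.activation(x)     # ... activate it ...
#             return super().forward(x)      # ... then convolve
#
# This is why conv3 in BottleneckBlockv2 carries a norm sized to its *input*
# channels (bottleneck_channels) rather than its output channels.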


class BUABasicStem(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns an `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}).
        """
        super().__init__()
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu_(x)
        # Caffe-style pooling: no padding, ceil_mode=True.
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
        return x

    @property
    def out_channels(self):
        return self.conv1.out_channels

    @property
    def stride(self):
        return 4  # = stride 2 conv -> stride 2 max pool
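

# Shape walk-through (illustrative): with the defaults, a 3x224x224 input
# becomes 64x112x112 after the stride-2 7x7 conv, then 64x56x56 after the
# stride-2 max pool, hence the reported stride of 4:
#
#     stem = BUABasicStem()
#     y = stem(torch.randn(1, 3, 224, 224))  # assumes `import torch`
#     # y.shape == (1, 64, 56, 56)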


class BUABasicStemv2(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns an `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}). NOTE: accepted for
                interface parity with `BUABasicStem`, but this v2 stem
                hardcodes `BatchNorm2d` with eps=2e-5.
        """
        super().__init__()
        # v2 normalizes the raw input before the first convolution
        # (pre-activation style).
        self.norm = BatchNorm2d(in_channels, eps=2e-5)
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=BatchNorm2d(out_channels, eps=2e-5),
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        x = self.norm(x)
        x = self.conv1(x)
        x = F.relu_(x)
        # Caffe-style pooling: no padding, ceil_mode=True.
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
        return x

    @property
    def out_channels(self):
        return self.conv1.out_channels

    @property
    def stride(self):
        return 4  # = stride 2 conv -> stride 2 max pool


@BACKBONE_REGISTRY.register()
def build_bua_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    if cfg.MODEL.BUA.RESNET_VERSION == 2:
        stem = BUABasicStemv2(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        )
    else:
        stem = BUABasicStem(
            in_channels=input_shape.channels,
            out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
            norm=norm,
        )
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT

    if freeze_at >= 1:
        for p in stem.parameters():
            p.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    # fmt: off
    out_features        = cfg.MODEL.RESNETS.OUT_FEATURES
    depth               = cfg.MODEL.RESNETS.DEPTH
    num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    stages = []

    # Avoid creating stages that are never used as outputs: their parameters
    # consume extra memory and may cause allreduce to fail.
    out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
    max_stage_idx = max(out_stage_idx)
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = res5_dilation if stage_idx == 5 else 1
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        stage_kargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "first_stride": first_stride,
            "in_channels": in_channels,
            "bottleneck_channels": bottleneck_channels,
            "out_channels": out_channels,
            "num_groups": num_groups,
            "norm": norm,
            "stride_in_1x1": stride_in_1x1,
            "dilation": dilation,
        }
        if deform_on_per_stage[idx]:
            stage_kargs["block_class"] = DeformBottleneckBlock
            stage_kargs["deform_modulated"] = deform_modulated
            stage_kargs["deform_num_groups"] = deform_num_groups
        else:
            # BottleneckBlockv2 is defined at the bottom of this file.
            stage_kargs["block_class"] = (
                BottleneckBlock if cfg.MODEL.BUA.RESNET_VERSION == 1 else BottleneckBlockv2
            )
        blocks = make_stage(**stage_kargs)
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)
    return ResNet(stem, stages, out_features=out_features)
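

# Usage sketch (illustrative, not part of the module): building the backbone
# from a detectron2 config. `cfg.MODEL.BUA.RESNET_VERSION` is a project-specific
# key, so a plain `get_cfg()` has to be extended with the MODEL.BUA node first.
#
#     from detectron2.config import get_cfg
#     from detectron2.layers import ShapeSpec
#
#     cfg = get_cfg()  # plus this project's custom MODEL.BUA config node
#     backbone = build_bua_resnet_backbone(cfg, ShapeSpec(channels=3))
#     features = backbone(images)  # dict keyed by cfg.MODEL.RESNETS.OUT_FEATURES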


class BottleneckBlockv2(ResNetBlockBase):
    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            norm (str or callable): a callable that takes the number of
                channels and returns an `nn.Module`, or a pre-defined string
                (one of {"FrozenBN", "BN", "GN"}). NOTE: accepted for
                interface parity, but this v2 block hardcodes `BatchNorm2d`
                with eps=2e-5.
            stride_in_1x1 (bool): when stride==2, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2dv2(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=None,
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv.
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations
        # have stride in the 3x3 conv.
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        # conv1 gets no norm/activation of its own: the shared pre-activation
        # of the block input (self.norm + relu) is applied in forward().
        self.conv1 = Conv2dv2(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=None,
        )

        # For conv2 and conv3, norm/activation are applied to the *input*
        # before convolving (see the Conv2dv2 note at the top of this file),
        # so each norm is sized to the incoming channel count.
        self.conv2 = Conv2dv2(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
            activation=F.relu_,
        )

        self.conv3 = Conv2dv2(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=BatchNorm2d(bottleneck_channels, eps=2e-5),
            activation=F.relu_,
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        # Pre-activation norm for the block input, used at the top of forward().
        self.norm = BatchNorm2d(in_channels, eps=2e-5)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."

        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        # Shared pre-activation: normalize and activate the block input once,
        # then feed it to both the residual branch and the projection shortcut.
        x_2 = self.norm(x)
        x_2 = F.relu_(x_2)

        # conv2 and conv3 apply their own norm + relu internally
        # (pre-activation), so no explicit activations are needed in between.
        out = self.conv1(x_2)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x_2)
        else:
            shortcut = x

        out += shortcut
        return out
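

# Smoke-test sketch (illustrative): a v2 block with a 1x1 projection shortcut
# maps (N, 64, H, W) to (N, 256, H, W) at stride 1. Kept as a comment because
# the relative import of Conv2dv2 prevents running this file as a script:
#
#     import torch
#
#     block = BottleneckBlockv2(64, 256, bottleneck_channels=64)
#     y = block(torch.randn(2, 64, 56, 56))
#     assert y.shape == (2, 256, 56, 56)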