# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.layers import ShapeSpec
from detectron2.modeling import RPN_HEAD_REGISTRY
from detectron2.modeling.anchor_generator import build_anchor_generator
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.proposal_generator import build_rpn_head
from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY

from .box_regression import BUABox2BoxTransform
from .rpn_outputs import BUARPNOutputs, find_top_bua_rpn_proposals


@RPN_HEAD_REGISTRY.register()
class StandardBUARPNHead(nn.Module):
    """
    RPN classification and regression heads. Uses a 3x3 conv to produce a shared
    hidden state from which one 1x1 conv predicts objectness logits for each anchor
    and a second 1x1 conv predicts bounding-box deltas specifying how to deform
    each anchor into an object proposal.
    """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        super().__init__()

        # Standard RPN is shared across levels:
        out_channels = cfg.MODEL.BUA.RPN.CONV_OUT_CHANNELS

        in_channels = [s.channels for s in input_shape]
        assert len(set(in_channels)) == 1, "Each level must have the same number of channels!"
        in_channels = in_channels[0]

        # RPNHead should take the same input as the anchor generator.
        # NOTE: this assumes that creating an anchor generator has no unwanted side effects.
        anchor_generator = build_anchor_generator(cfg, input_shape)
        num_cell_anchors = anchor_generator.num_cell_anchors
        box_dim = anchor_generator.box_dim
        assert (
            len(set(num_cell_anchors)) == 1
        ), "Each level must have the same number of cell anchors"
        num_cell_anchors = num_cell_anchors[0]

        # 3x3 conv for the hidden representation
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 1x1 conv predicting two objectness scores (background, foreground) per
        # cell anchor, hence `num_cell_anchors * 2` output channels; the standard
        # detectron2 RPN head predicts a single sigmoid logit per anchor instead.
        self.objectness_logits = nn.Conv2d(
            out_channels, num_cell_anchors * 2, kernel_size=1, stride=1
        )
        # 1x1 conv for predicting box2box transform deltas
        self.anchor_deltas = nn.Conv2d(
            out_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1
        )

        for layer in [self.conv, self.objectness_logits, self.anchor_deltas]:
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.constant_(layer.bias, 0)

    def forward(self, features):
        """
        Args:
            features (list[Tensor]): list of feature maps

        Returns:
            pred_objectness_logits (list[Tensor]): one (N, 2*A, Hi, Wi) tensor per level
            pred_anchor_deltas (list[Tensor]): one (N, A*box_dim, Hi, Wi) tensor per level
        """
        pred_objectness_logits = []
        pred_anchor_deltas = []
        for x in features:
            t = F.relu(self.conv(x))
            pred_objectness_logits.append(self.objectness_logits(t))
            pred_anchor_deltas.append(self.anchor_deltas(t))
        return pred_objectness_logits, pred_anchor_deltas
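

# Minimal shape sketch, not part of the original pipeline: it mirrors the three
# convs of StandardBUARPNHead with assumed sizes (in_channels=256,
# out_channels=512, 12 cell anchors, box_dim=4) to show the channel layout the
# head produces at each feature level.
def _shape_sketch_standard_bua_rpn_head():
    conv = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
    objectness_logits = nn.Conv2d(512, 12 * 2, kernel_size=1, stride=1)
    anchor_deltas = nn.Conv2d(512, 12 * 4, kernel_size=1, stride=1)
    t = F.relu(conv(torch.zeros(1, 256, 14, 14)))
    # Two (bg, fg) scores per cell anchor and box_dim deltas per cell anchor:
    assert objectness_logits(t).shape == (1, 12 * 2, 14, 14)
    assert anchor_deltas(t).shape == (1, 12 * 4, 14, 14)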


@PROPOSAL_GENERATOR_REGISTRY.register()
class BUARPN(nn.Module):
    """
    Region Proposal Network, introduced by the Faster R-CNN paper. This variant
    wires in the bottom-up-attention (BUA) head, box2box transform, and output
    logic (`BUARPNOutputs`, `find_top_bua_rpn_proposals`).
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        # fmt: off
        self.min_box_side_len     = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
        self.in_features          = cfg.MODEL.RPN.IN_FEATURES
        self.nms_thresh           = cfg.MODEL.RPN.NMS_THRESH
        self.batch_size_per_image = cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE
        self.positive_fraction    = cfg.MODEL.RPN.POSITIVE_FRACTION
        self.smooth_l1_beta       = cfg.MODEL.RPN.SMOOTH_L1_BETA
        self.loss_weight          = cfg.MODEL.RPN.LOSS_WEIGHT
        # fmt: on

        # Map from self.training state to train/test settings,
        # e.g. self.pre_nms_topk[self.training] picks the right value.
        self.pre_nms_topk = {
            True: cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN,
            False: cfg.MODEL.RPN.PRE_NMS_TOPK_TEST,
        }
        self.post_nms_topk = {
            True: cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN,
            False: cfg.MODEL.RPN.POST_NMS_TOPK_TEST,
        }
        self.boundary_threshold = cfg.MODEL.RPN.BOUNDARY_THRESH

        self.anchor_generator = build_anchor_generator(
            cfg, [input_shape[f] for f in self.in_features]
        )
        self.box2box_transform = BUABox2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
        self.anchor_matcher = Matcher(
            cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True
        )
        self.rpn_head = build_rpn_head(cfg, [input_shape[f] for f in self.in_features])

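    # Illustrative config fragment covering the keys read in __init__ above; the
    # values are assumptions for the sketch, not this repo's verified defaults:
    #
    #   MODEL:
    #     PROPOSAL_GENERATOR:
    #       MIN_SIZE: 0
    #     RPN:
    #       IN_FEATURES: ["res4"]
    #       NMS_THRESH: 0.7
    #       BATCH_SIZE_PER_IMAGE: 256
    #       POSITIVE_FRACTION: 0.5
    #       SMOOTH_L1_BETA: 0.0
    #       LOSS_WEIGHT: 1.0
    #       PRE_NMS_TOPK_TRAIN: 12000
    #       PRE_NMS_TOPK_TEST: 6000
    #       POST_NMS_TOPK_TRAIN: 2000
    #       POST_NMS_TOPK_TEST: 1000
    #       BOUNDARY_THRESH: -1
    #       BBOX_REG_WEIGHTS: [1.0, 1.0, 1.0, 1.0]
    #       IOU_THRESHOLDS: [0.3, 0.7]
    #       IOU_LABELS: [0, -1, 1]
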
    def forward(self, images, features, gt_instances=None):
        """
        Args:
            images (ImageList): input images of length `N`
            features (dict[str, Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            gt_instances (list[Instances], optional): a length `N` list of `Instances`.
                Each `Instances` stores ground-truth instances for the corresponding image.

        Returns:
            proposals: list[Instances] or None
            loss: dict[str, Tensor]
        """
        gt_boxes = [x.gt_boxes for x in gt_instances] if gt_instances is not None else None
        del gt_instances
        features = [features[f] for f in self.in_features]
        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
        anchors_in_image = self.anchor_generator(features)
        # Replicate the per-level anchors once per image; len(features[0]) is the
        # batch size N (the length of the first tensor axis).
        anchors = [copy.deepcopy(anchors_in_image) for _ in range(len(features[0]))]
        # TODO: The anchors only depend on the feature map shape; there's probably
        # an opportunity for some optimizations (e.g., caching anchors).
        outputs = BUARPNOutputs(
            self.box2box_transform,
            self.anchor_matcher,
            self.batch_size_per_image,
            self.positive_fraction,
            images,
            pred_objectness_logits,
            pred_anchor_deltas,
            anchors,
            self.boundary_threshold,
            gt_boxes,
            self.smooth_l1_beta,
        )

        if self.training:
            losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
        else:
            losses = {}

        with torch.no_grad():
            # Find the top proposals by applying NMS and removing boxes that
            # are too small. The proposals are treated as fixed for approximate
            # joint training with roi heads. This approach ignores the derivative
            # w.r.t. the proposal boxes' coordinates that are also network
            # responses, so is approximate.
            proposals = find_top_bua_rpn_proposals(
                outputs.predict_proposals(),
                outputs.predict_objectness_logits(),
                images,
                self.nms_thresh,
                self.pre_nms_topk[self.training],
                self.post_nms_topk[self.training],
                self.min_box_side_len,
                self.training,
            )
            # For RPN-only models, the proposals are the final output; upstream
            # detectron2 re-sorts them here into high-to-low confidence order, but
            # that re-sort is left disabled in this BUA variant.
            # For end-to-end models, the RPN proposals are an intermediate state
            # and the sorting is not needed (its cost would be negligible anyway).
            # inds = [p.objectness_logits.sort(descending=True)[1] for p in proposals]
            # proposals = [p[ind] for p, ind in zip(proposals, inds)]

        return proposals, losses
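

# Minimal usage sketch, assuming a config that selects this module (e.g.
# cfg.MODEL.PROPOSAL_GENERATOR.NAME = "BUARPN" plus the BUA-specific keys such
# as cfg.MODEL.BUA.RPN.CONV_OUT_CHANNELS); `backbone_shape`, `images`,
# `features` and `gt_instances` stand in for the caller's data:
#
#   from detectron2.modeling import build_proposal_generator
#
#   rpn = build_proposal_generator(cfg, backbone_shape)        # -> BUARPN
#   proposals, losses = rpn(images, features, gt_instances)    # training
#   proposals, _ = rpn(images, features)                       # inference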