# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging

import numpy as np
import torch
import cv2

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

from .transform_gen import ResizeShortestEdge
from .detection_utils import annotations_to_instances

"""
|
||
|
This file contains the default mapping that's applied to "dataset dicts".
|
||
|
"""
|
||
|
|
||
|
__all__ = ["DatasetMapper"]
|
||
|
|
||
|
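# For reference, a minimal (hypothetical) dataset dict in the Detectron2 Dataset
# format that the mapper below consumes; the field names follow detectron2's
# dataset documentation, the concrete values are made up for illustration:
#
#     {
#         "file_name": "images/0001.jpg",
#         "height": 480,
#         "width": 640,
#         "image_id": 1,
#         "annotations": [
#             {"bbox": [10.0, 20.0, 200.0, 150.0],
#              "bbox_mode": BoxMode.XYXY_ABS,   # from detectron2.structures
#              "category_id": 0,
#              "iscrowd": 0},
#         ],
#     }
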
def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.
    Now it includes only resizing (no flipping is added here).

    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST

    logger = logging.getLogger(__name__)
    tfm_gens = []
    tfm_gens.append(ResizeShortestEdge(min_size, max_size, cfg.MODEL.PIXEL_MEAN))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens

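# Usage sketch (illustrative; assumes a detectron2-style CfgNode such as the one
# returned by detectron2.config.get_cfg(), plus whatever custom keys this project
# adds on top; note the ResizeShortestEdge above is this project's local variant
# that takes cfg.MODEL.PIXEL_MEAN as its third argument):
#
#     from detectron2.config import get_cfg
#     cfg = get_cfg()
#     tfm_gens = build_transform_gen(cfg, is_train=False)
#     # -> a one-element list [ResizeShortestEdge(...)] that resizes the shortest
#     #    edge to cfg.INPUT.MIN_SIZE_TEST, capped at cfg.INPUT.MAX_SIZE_TEST

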
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepares data and annotations to Tensor and :class:`Instances`
    """

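    # Minimal direct use (illustrative; `d` is a dataset dict like the sketch at
    # the top of this file, `cfg` a config as in the build_transform_gen sketch):
    #
    #     mapper = DatasetMapper(cfg, is_train=True)
    #     example = mapper(d)
    #     example["image"].shape   # -> torch.Size([3, new_h, new_w])
    #     example["im_scale"]      # -> new_h / original_h
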
    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format     = cfg.INPUT.FORMAT
        self.mask_on        = cfg.MODEL.MASK_ON
        self.mask_format    = cfg.INPUT.MASK_FORMAT
        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.load_proposals:
            self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        # image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        image = cv2.imread(dataset_dict["file_name"])  # note: cv2.imread returns BGR (HWC)
        h, w = image.shape[:2]  # original image size, used for "im_scale" below
        # utils.check_image_size(dataset_dict, image)

if "annotations" not in dataset_dict:
|
||
|
image, transforms = T.apply_transform_gens(
|
||
|
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
|
||
|
)
|
||
|
else:
|
||
|
# Crop around an instance if there are instances in the image.
|
||
|
# USER: Remove if you don't use cropping
|
||
|
if self.crop_gen:
|
||
|
crop_tfm = utils.gen_crop_transform_with_instance(
|
||
|
self.crop_gen.get_crop_size(image.shape[:2]),
|
||
|
image.shape[:2],
|
||
|
np.random.choice(dataset_dict["annotations"]),
|
||
|
)
|
||
|
image = crop_tfm.apply_image(image)
|
||
|
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
||
|
if self.crop_gen:
|
||
|
transforms = crop_tfm + transforms
|
||
|
|
||
|
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        # Can use uint8 if it turns out to be slow some day
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        # Resize scale: ratio of the transformed height to the original height.
        dataset_dict["im_scale"] = float(image_shape[0]) / float(h)

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

if "annotations" in dataset_dict:
|
||
|
# USER: Modify this if you want to keep them for some reason.
|
||
|
for anno in dataset_dict["annotations"]:
|
||
|
if not self.mask_on:
|
||
|
anno.pop("segmentation", None)
|
||
|
if not self.keypoint_on:
|
||
|
anno.pop("keypoints", None)
|
||
|
|
||
|
# USER: Implement additional transformations if you have other types of data
|
||
|
annos = [
|
||
|
utils.transform_instance_annotations(
|
||
|
obj, transforms, image_shape
|
||
|
)
|
||
|
for obj in dataset_dict.pop("annotations")
|
||
|
if obj.get("iscrowd", 0) == 0
|
||
|
]
|
||
|
instances = annotations_to_instances(
|
||
|
annos, image_shape, mask_format=self.mask_format
|
||
|
)
|
||
|
# Create a tight bounding box from masks, useful when image is cropped
|
||
|
if self.crop_gen and instances.has("gt_masks"):
|
||
|
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
||
|
dataset_dict["instances"] = utils.filter_empty_instances(instances)
|
||
|
|
||
|
return dataset_dict
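
# Usage sketch (illustrative, not part of this module's API): the mapper is
# normally handed to detectron2's dataloader builders. build_detection_train_loader
# is the stock detectron2 entry point; whether this project calls it directly
# is an assumption.
#
#     from detectron2.data import build_detection_train_loader
#     loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, is_train=True))
#     for batched_inputs in loader:
#         # each element has "image" (CHW float32 tensor), "im_scale", and,
#         # in training, "instances"
#         ...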