# Graduation_Project/LHL/dataloader/dataset_mapper.py
#
# 165 lines
# 6.3 KiB
# Python

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
import cv2
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from .transform_gen import ResizeShortestEdge
from .detection_utils import annotations_to_instances
"""
This file contains the default mapping that's applied to "dataset dicts".
"""
# Names exported by ``from <this module> import *``: only the mapper class.
# ``build_transform_gen`` remains importable by explicit name.
__all__ = ["DatasetMapper"]
def build_transform_gen(cfg, is_train):
    """
    Assemble the list of :class:`TransformGen` described by the config.

    The list holds a single shortest-edge resize; no flipping or any other
    augmentation is added by this function.

    Args:
        cfg: config node. Reads ``INPUT.MIN_SIZE_TRAIN`` / ``INPUT.MAX_SIZE_TRAIN``
            when ``is_train`` is true, ``INPUT.MIN_SIZE_TEST`` /
            ``INPUT.MAX_SIZE_TEST`` otherwise, and always ``MODEL.PIXEL_MEAN``.
        is_train (bool): selects the train or test size limits and enables
            the informational log line.

    Returns:
        list[TransformGen]: a one-element list containing the resize transform.
    """
    input_cfg = cfg.INPUT
    if is_train:
        size_bounds = (input_cfg.MIN_SIZE_TRAIN, input_cfg.MAX_SIZE_TRAIN)
    else:
        size_bounds = (input_cfg.MIN_SIZE_TEST, input_cfg.MAX_SIZE_TEST)

    log = logging.getLogger(__name__)

    # PIXEL_MEAN is forwarded as the third argument of the project-local
    # ResizeShortestEdge. NOTE(review): how it is used there is not visible
    # from this file -- confirm in transform_gen.py.
    resize_gen = ResizeShortestEdge(*size_bounds, cfg.MODEL.PIXEL_MEAN)
    tfm_gens = [resize_gen]

    if is_train:
        log.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic.

    The callable currently does the following:

    1. Read the image from "file_name" with ``cv2.imread``
    2. Apply cropping/geometric transforms to the image and annotations
    3. Convert the data and annotations to Tensor and :class:`Instances`

    In addition to the usual fields, the returned dict carries ``"im_scale"``,
    the ratio between the transformed image height and the height of the
    image as read from disk.
    """

    def __init__(self, cfg, is_train=True):
        """
        Args:
            cfg: config node. Reads ``INPUT.CROP``, ``INPUT.FORMAT``,
                ``INPUT.MASK_FORMAT``, ``MODEL.MASK_ON``, ``MODEL.KEYPOINT_ON``,
                ``MODEL.LOAD_PROPOSALS`` and, when proposals are loaded,
                ``MODEL.PROPOSAL_GENERATOR.MIN_SIZE`` and
                ``DATASETS.PRECOMPUTED_PROPOSAL_TOPK_{TRAIN,TEST}``.
            is_train (bool): whether the mapper is used for training. Cropping
                is only ever enabled in training.
        """
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = build_transform_gen(cfg, is_train)

        # fmt: off
        # NOTE: img_format is stored but not consulted by __call__, which
        # reads the image with cv2.imread instead of utils.read_image.
        self.img_format     = cfg.INPUT.FORMAT
        self.mask_on        = cfg.MODEL.MASK_ON
        self.mask_format    = cfg.INPUT.MASK_FORMAT
        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        # fmt: on

        # NOTE: keypoint_hflip_indices is computed for parity with the stock
        # mapper; __call__ does not pass it on (no flip transform is built
        # by build_transform_gen).
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.load_proposals:
            self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept

        Raises:
            OSError: if the image at ``dataset_dict["file_name"]`` cannot be
                read (missing file, unreadable file or unsupported format).
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

        # USER: Write your own image loading if it's not from a file
        file_name = dataset_dict["file_name"]
        image = cv2.imread(file_name)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than with an opaque
            # AttributeError on ``image.shape`` below.
            raise OSError(
                "Failed to read image '{}': file is missing or cannot be decoded".format(file_name)
            )
        orig_height = image.shape[0]

        annotations = dataset_dict.get("annotations")
        if self.crop_gen and annotations:
            # Crop around a randomly chosen instance so that the crop is
            # guaranteed to contain at least one object.
            # USER: Remove if you don't use cropping
            crop_tfm = utils.gen_crop_transform_with_instance(
                self.crop_gen.get_crop_size(image.shape[:2]),
                image.shape[:2],
                np.random.choice(annotations),
            )
            image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            transforms = crop_tfm + transforms
        else:
            # No instance to crop around (no annotations, or an empty list,
            # for which np.random.choice would raise): apply the plain random
            # crop, if enabled, followed by the regular transforms.
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        # Can use uint8 if it turns out to be slow some day
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        # Ratio of transformed height to original height.
        dataset_dict["im_scale"] = float(image_shape[0]) / float(orig_height)

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            # Ground truth is not handed to the model at inference time.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            # Crowd annotations (iscrowd != 0) are dropped.
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        return dataset_dict