Graduation_Project/LHL/dataloader/detection_utils.py

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import torch
from detectron2.structures import (
    Boxes,
    BoxMode,
    Instances,
)


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to the box annotation of a single instance.

    It uses `transforms.apply_box` for the box. (The upstream detectron2
    version also transforms segmentation polygons and keypoints via
    `transforms.apply_coords`; this trimmed copy handles only the box.)
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList): the transforms to apply.
        image_size (tuple): the height, width of the transformed image.
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`
            (unused here; kept for API compatibility with detectron2).

    Returns:
        dict:
            the same input dict with the "bbox" field transformed according
            to `transforms`, and "bbox_mode" set to XYXY_ABS.
    """
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    # Note that bbox is 1d (per-instance bounding box)
    annotation["bbox"] = transforms.apply_box([bbox])[0]
    annotation["bbox_mode"] = BoxMode.XYXY_ABS
    # "attributes" carry no coordinates, so they pass through unchanged.
    return annotation


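# Usage sketch (illustrative, not part of the original module): build a
# TransformList by hand and run one annotation through it. The concrete
# sizes and the XYWH_ABS input mode below are made-up example values.
#
#   from detectron2.data.transforms import ResizeTransform, TransformList
#
#   tfms = TransformList([ResizeTransform(h=480, w=640, new_h=800, new_w=1067)])
#   anno = {"bbox": [10.0, 20.0, 50.0, 80.0], "bbox_mode": BoxMode.XYWH_ABS}
#   anno = transform_instance_annotations(anno, tfms, image_size=(800, 1067))
#   # anno["bbox"] is now an XYXY_ABS box in the 800x1067 resized image,
#   # and anno["bbox_mode"] == BoxMode.XYXY_ABS.

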
def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width.
        mask_format (str): unused in this trimmed copy; kept for API
            compatibility with detectron2.

    Returns:
        Instances:
            It will contain the fields "gt_boxes", "gt_classes" and
            "gt_attributes". This is the format that builtin models expect.
    """
    boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = Boxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    # Collect per-instance attribute vectors; instances without annotated
    # attributes are padded with -1 (16 is the expected attribute count here).
    attributes = []
    for obj in annos:
        if "attributes" in obj:
            attributes.append(obj["attributes"])
        else:
            attributes.append([-1] * 16)
    attributes = torch.tensor(attributes, dtype=torch.int64)
    target.gt_attributes = attributes
    return target
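

# Usage sketch (illustrative; all annotation values below are made up):
#
#   annos = [
#       {"bbox": [10, 20, 50, 80], "bbox_mode": BoxMode.XYWH_ABS,
#        "category_id": 3, "attributes": [0] * 16},
#       {"bbox": [5, 5, 30, 40], "bbox_mode": BoxMode.XYWH_ABS,
#        "category_id": 1},  # no attributes -> padded with -1
#   ]
#   instances = annotations_to_instances(annos, image_size=(480, 640))
#   # instances.gt_boxes      -> Boxes in XYXY_ABS, clipped to the image
#   # instances.gt_classes    -> tensor([3, 1])
#   # instances.gt_attributes -> int64 tensor of shape (2, 16); the second
#   #                            row is all -1 because attributes were missing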