513 lines
22 KiB
Python
513 lines
22 KiB
Python
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
import contextlib
|
|
import json
|
|
from collections import defaultdict
|
|
from itertools import repeat
|
|
from multiprocessing.pool import ThreadPool
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from torch.utils.data import ConcatDataset
|
|
|
|
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
|
|
from ultralytics.utils.ops import resample_segments
|
|
from ultralytics.utils.torch_utils import TORCHVISION_0_18
|
|
|
|
from .augment import (
|
|
Compose,
|
|
Format,
|
|
Instances,
|
|
LetterBox,
|
|
RandomLoadText,
|
|
classify_augmentations,
|
|
classify_transforms,
|
|
v8_transforms,
|
|
)
|
|
from .base import BaseDataset
|
|
from .utils import (
|
|
HELP_URL,
|
|
LOGGER,
|
|
get_hash,
|
|
img2label_paths,
|
|
load_dataset_cache_file,
|
|
save_dataset_cache_file,
|
|
verify_image,
|
|
verify_image_label,
|
|
)
|
|
|
|
# Version tag stored in Ultralytics dataset *.cache files and compared on load (>= 1.0.0 for YOLOv8);
# caches written with any other version are rebuilt.
DATASET_CACHE_VERSION: str = "1.0.3"
|
|
|
|
|
|
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Every (image, label file) pair is checked in a thread pool with `verify_image_label`. A label dict is
        collected for each image that passes. Warnings are logged, and the result is written with
        `save_dataset_cache_file`.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): Cache dictionary with the following keys:
                'labels': list of per-image label dicts.
                'hash': hash of the label and image file lists.
                'results': tuple (found, missing, empty, corrupt, total).
                'msgs': list of warning strings.

        Raises:
            ValueError: If the task is 'pose' and `self.data['kpt_shape']` is missing or invalid.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        # NOTE: requires `self.data` to be a dict; the `data=None` default would fail here.
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # falsy for images that failed verification
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self):
        """
        Returns dictionary of labels for YOLO training.

        The ``*.cache`` file next to the label directory is reused when its version and file hash match.
        Otherwise it is rebuilt with `cache_labels`. Segments are dropped when box and segment counts disagree.

        Returns:
            (list[dict]): One label dict per image, as built by `cache_labels`. Keys are 'im_file', 'shape',
                'cls', 'bboxes', 'segments', 'keypoints', 'normalized' and 'bbox_format'.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache = load_dataset_cache_file(cache_path)  # attempt to load a *.cache file
            # Explicit comparisons instead of `assert`: asserts are stripped under `python -O`, which would let a
            # cache written by another version or for another file list be reused silently.
            exists = cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(
                self.label_files + self.im_files
            )
        except (FileNotFoundError, AssertionError, AttributeError, KeyError):
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache: drop bookkeeping entries (tolerating any that are absent)
        for k in ("hash", "version", "msgs"):
            cache.pop(k, None)
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments. Plain sums evaluate to 0 for an empty label list,
        # whereas unpacking `zip(*...)` of an empty iterable would raise ValueError.
        len_cls = sum(len(lb["cls"]) for lb in labels)
        len_boxes = sum(len(lb["bboxes"]) for lb in labels)
        len_segments = sum(len(lb["segments"]) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """
        Builds and appends transforms to the list.

        Args:
            hyp: Hyperparameter namespace. It is required in practice despite the None default, because
                `hyp.mask_ratio`, `hyp.overlap_mask` and `hyp.bgr` are always read. In augment mode with `rect`,
                its `mosaic` and `mixup` attributes are set to 0.0 in place.

        Returns:
            Compose pipeline ending with a `Format` transform.
        """
        if self.augment:
            if self.rect:  # mosaic and mixup are disabled for rectangular training
                hyp.mosaic = 0.0
                hyp.mixup = 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Pops 'bboxes', 'segments', 'keypoints', 'bbox_format' and 'normalized' from `label` and stores them
        together as an `Instances` object under 'instances'.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: OBB polygons are resampled to 100 points, other segments to 1000 points.
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # Stack the per-instance resampled polygons into (N, segment_resamples, 2)
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """
        Collates data samples into batches.

        Args:
            batch (list[dict]): Per-sample dicts; all samples must contain the keys of the first one.

        Returns:
            (dict): The collated batch, with values combined per key as follows:
                'img': stacked along a new dim 0.
                'masks', 'keypoints', 'bboxes', 'cls', 'segments', 'obb': concatenated along dim 0.
                'batch_idx': concatenated after each sample's entries are offset by its position in the batch.
                Any other key: kept as a tuple of per-sample values.
        """
        new_batch = {}
        for k in batch[0]:
            # Gather by key (not by position) so samples whose dicts order keys differently still line up.
            value = tuple(b[k] for b in batch)
            if k == "img":
                value = torch.stack(value, 0)
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Add target image index for build_targets(); out-of-place addition leaves the per-sample tensors untouched.
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch
|
|
|
|
|
|
class YOLOMultiModalDataset(YOLODataset):
    """
    Dataset class for multi-modal (image + text) training on YOLO-format detection and/or segmentation labels.

    Behaves like YOLODataset, but also attaches the dataset's class-name texts to every label. When augmenting,
    it inserts a `RandomLoadText` transform before the final formatting step.

    Args:
        data (dict, optional): A dataset YAML dictionary; 'names' and 'nc' are read from it. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes a dataset object for object detection tasks with optional specifications."""
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """Add texts information for multi-modal model training."""
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with their synonyms by `/`, so each entry is a list of synonyms.
        labels["texts"] = [v.split("/") for v in self.data["names"].values()]
        return labels

    def build_transforms(self, hyp=None):
        """Enhances data transformations with optional text augmentation for multi-modal training."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now. Inserted at -1, i.e. just before the trailing Format transform.
            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
        return transforms
|
|
|
|
|
|
class GroundingDataset(YOLODataset):
    """
    Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.

    Image paths and labels come from the JSON file rather than from label text files. The code reads
    `images` (id, height, width, file_name, caption) and `annotations` (image_id, bbox, tokens_positive,
    optional iscrowd).
    """

    def __init__(self, *args, task="detect", json_file, **kwargs):
        """
        Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file.

        Args:
            *args: Positional arguments forwarded to YOLODataset.
            task (str): Must be 'detect'; other tasks are not supported yet.
            json_file (str | Path): Path to the annotation JSON file.
            **kwargs: Keyword arguments forwarded to YOLODataset.
        """
        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
        self.json_file = json_file
        super().__init__(*args, task=task, data={}, **kwargs)

    def get_img_files(self, img_path):
        """The image files would be read in `get_labels` function, return empty list here."""
        return []

    def get_labels(self):
        """
        Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.

        The following are dropped: images missing on disk, crowd annotations, boxes with non-positive width or
        height, and exact duplicate (class, box) rows. Each distinct caption phrase of an image becomes one class
        whose text is stored in 'texts'. Every kept image path is appended to `self.im_files`.

        Returns:
            (list[dict]): One dict per kept image with the following keys:
                'im_file': image path as str.
                'shape': (h, w).
                'cls': array of shape (n, 1).
                'bboxes': array of shape (n, 4), normalized xywh.
                'normalized', 'bbox_format', 'texts'.
        """
        labels = []
        LOGGER.info("Loading annotation file...")
        # JSON text is UTF-8 by specification; do not rely on the platform's default encoding.
        with open(self.json_file, encoding="utf-8") as fh:
            annotations = json.load(fh)
        images = {f'{x["id"]:d}': x for x in annotations["images"]}
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, file_name = img["height"], img["width"], img["file_name"]
            im_path = Path(self.img_path) / file_name
            if not im_path.exists():
                continue
            im_file = str(im_path)  # str, matching self.im_files and YOLODataset labels
            self.im_files.append(im_file)
            bboxes = []
            cat2id = {}
            texts = []
            for ann in anns:
                if ann.get("iscrowd", 0):
                    continue
                # Presumably [x_min, y_min, w, h] in pixels (COCO convention): shift to center, then normalize.
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2
                box[[0, 2]] /= float(w)
                box[[1, 3]] /= float(h)
                if box[2] <= 0 or box[3] <= 0:
                    continue

                # The class name is the caption text covered by the annotation's token spans.
                cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)
                    texts.append([cat_name])
                cls = cat2id[cat_name]  # class
                box = [cls] + box.tolist()
                if box not in bboxes:
                    bboxes.append(box)
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            labels.append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        return labels

    def build_transforms(self, hyp=None):
        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now. Inserted at -1, i.e. just before the trailing Format transform.
            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
        return transforms
|
|
|
|
|
|
class YOLOConcatDataset(ConcatDataset):
    """
    Concatenation of several datasets exposed as a single dataset.

    Use it to combine existing datasets into one. Batches are collated the same way as for YOLODataset.
    """

    @staticmethod
    def collate_fn(batch):
        """Collate a list of samples into one batch by delegating to `YOLODataset.collate_fn`."""
        collate = YOLODataset.collate_fn
        return collate(batch)
|
|
|
|
|
|
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Placeholder dataset for semantic segmentation.

    It derives from BaseDataset but adds no behavior yet. The methods and attributes needed for semantic
    segmentation still have to be implemented.
    """

    def __init__(self):
        """Create the (placeholder) dataset by calling BaseDataset's initializer with no arguments."""
        super(SemanticDataset, self).__init__()
|
|
|
|
|
|
class ClassificationDataset:
    """
    Wraps torchvision ImageFolder to support YOLO classification tasks.

    It provides image augmentation, caching and verification, and is designed for training on large datasets.
    The ImageFolder instance is held as an attribute rather than used as a base class, which keeps the slow
    torchvision import local.

    Augmentations are built with `classify_augmentations` when `augment` is True, and with `classify_transforms`
    otherwise. Images can be cached in RAM or on disk to reduce IO overhead during training, and all images are
    verified on construction.

    Attributes:
        base: The underlying `torchvision.datasets.ImageFolder`.
        root: Dataset root directory, taken from `base.root`.
        prefix (str): Colored logging prefix, or '' when no prefix was given.
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): Lists of [image path, class index, path to its .npy cache file, loaded image or None].
            The last element is filled on first access when caching in RAM.
        torch_transforms (callable): PyTorch transforms to be applied to the images.
    """

    def __init__(self, root, args, augment=False, prefix=""):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration with dataset-related settings. The attributes read are `imgsz`,
                `fraction` (fraction of data to use), `scale`, `fliplr`, `flipud`, `erasing`, `cache` (disk or RAM
                caching for faster training), `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v` and `crop_fraction`.
            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
                debugging. Default is an empty string.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
        if TORCHVISION_0_18:  # 'allow_empty' argument first introduced in torchvision 0.18
            self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
        else:
            self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

        # Initialize attributes
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
        )

    def __getitem__(self, i):
        """Return ``{"img": transformed image, "cls": class index}`` for the sample at index ``i``."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            # Warning: two separate if statements required here, do not combine this with previous line.
            # If combined, a RAM-cached image (im is not None) would fall through to the disk/read branches.
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # create the .npy cache on first access
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """
        Verify all images in dataset.

        A ``*.cache`` file next to `root` is reused when its version and file hash match. Otherwise every sample is
        checked with `verify_image` in a thread pool and the cache is rewritten.

        Returns:
            (list): The samples that passed verification.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path
        files_hash = get_hash([sample[0] for sample in self.samples])

        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError, KeyError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            # Explicit comparisons instead of `assert`: asserts are stripped under `python -O`, which would let a
            # cache written by another version or for another file list be reused silently.
            if cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == files_hash:
                nf, nc, n, samples = cache.pop("results")  # found, corrupt, total, valid samples
                if LOCAL_RANK in {-1, 0}:
                    d = f"{desc} {nf} images, {nc} corrupt"
                    TQDM(None, desc=d, total=n, initial=n)
                    if cache["msgs"]:
                        LOGGER.info("\n".join(cache["msgs"]))  # display warnings
                return samples

        # Run scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = files_hash
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return samples
|