943 lines
33 KiB
Python
943 lines
33 KiB
Python
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||
|
|
||
|
import contextlib
|
||
|
import inspect
|
||
|
import logging.config
|
||
|
import os
|
||
|
import platform
|
||
|
import re
|
||
|
import subprocess
|
||
|
import sys
|
||
|
import threading
|
||
|
import urllib
|
||
|
import uuid
|
||
|
from pathlib import Path
|
||
|
from types import SimpleNamespace
|
||
|
from typing import Union
|
||
|
|
||
|
import cv2
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import yaml
|
||
|
from tqdm import tqdm as tqdm_original
|
||
|
|
||
|
from ultralytics import __version__
|
||
|
|
||
|
# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html

# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO
ASSETS = ROOT / 'assets'  # default images
DEFAULT_CFG_PATH = ROOT / 'cfg/default.yaml'
# os.cpu_count() returns None when the CPU count cannot be determined, so fall back to 1 before subtracting
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
# Environment flags are enabled only when the variable is unset or spells 'true' (case-insensitive)
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true'  # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' if VERBOSE else None  # tqdm bar format
LOGGING_NAME = 'ultralytics'
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows'])  # environment booleans
ARM64 = platform.machine() in ('arm64', 'aarch64')  # ARM64 booleans
|
||
|
HELP_MSG = \
|
||
|
"""
|
||
|
Usage examples for running YOLOv8:
|
||
|
|
||
|
1. Install the ultralytics package:
|
||
|
|
||
|
pip install ultralytics
|
||
|
|
||
|
2. Use the Python SDK:
|
||
|
|
||
|
from ultralytics import YOLO
|
||
|
|
||
|
# Load a model
|
||
|
model = YOLO('yolov8n.yaml') # build a new model from scratch
|
||
|
model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
|
||
|
|
||
|
# Use the model
|
||
|
results = model.train(data="coco128.yaml", epochs=3) # train the model
|
||
|
results = model.val() # evaluate model performance on the validation set
|
||
|
results = model('https://ultralytics.com/images/bus.jpg') # predict on an image
|
||
|
success = model.export(format='onnx') # export the model to ONNX format
|
||
|
|
||
|
3. Use the command line interface (CLI):
|
||
|
|
||
|
YOLOv8 'yolo' CLI commands use the following syntax:
|
||
|
|
||
|
yolo TASK MODE ARGS
|
||
|
|
||
|
Where TASK (optional) is one of [detect, segment, classify]
|
||
|
MODE (required) is one of [train, val, predict, export]
|
||
|
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||
|
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
||
|
|
||
|
- Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||
|
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||
|
|
||
|
- Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||
|
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320
|
||
|
|
||
|
- Val a pretrained detection model at batch-size 1 and image size 640:
|
||
|
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||
|
|
||
|
- Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||
|
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||
|
|
||
|
- Run special commands:
|
||
|
yolo help
|
||
|
yolo checks
|
||
|
yolo version
|
||
|
yolo settings
|
||
|
yolo copy-cfg
|
||
|
yolo cfg
|
||
|
|
||
|
Docs: https://docs.ultralytics.com
|
||
|
Community: https://community.ultralytics.com
|
||
|
GitHub: https://github.com/ultralytics/ultralytics
|
||
|
"""
|
||
|
|
||
|
# Settings
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ.update({
    'NUMEXPR_MAX_THREADS': str(NUM_THREADS),  # NumExpr max threads
    'CUBLAS_WORKSPACE_CONFIG': ':4096:8',  # for deterministic training
    'TF_CPP_MIN_LOG_LEVEL': '2',  # suppress verbose TF compiler warnings in Colab
})
|
||
|
|
||
|
|
||
|
class TQDM(tqdm_original):
    """
    Ultralytics progress bar: a tqdm subclass that applies package-wide defaults.

    Args:
        *args (list): Positional arguments forwarded unchanged to the original tqdm.
        **kwargs (dict): Keyword arguments forwarded to tqdm after the defaults below are applied.
    """

    def __init__(self, *args, **kwargs):
        """Apply the 'disable' and 'bar_format' defaults, then initialize the underlying tqdm bar."""
        requested_disable = kwargs.get('disable', False)
        # The bar is disabled when global verbosity is off OR when the caller asked for it
        kwargs['disable'] = not VERBOSE or requested_disable
        # A caller-supplied bar format always wins over the package default
        if 'bar_format' not in kwargs:
            kwargs['bar_format'] = TQDM_BAR_FORMAT
        super().__init__(*args, **kwargs)
|
||
|
|
||
|
|
||
|
class SimpleClass:
    """Base class giving subclasses a readable string form and a descriptive error for unknown attributes."""

    def __str__(self):
        """Return a multi-line listing of every public, non-callable attribute of the object."""
        lines = []
        for name in dir(self):
            value = getattr(self, name)
            if callable(value) or name.startswith('_'):
                continue  # skip methods and private/dunder names
            if isinstance(value, SimpleClass):
                # Nested SimpleClass objects are summarised by module and class name only
                lines.append(f'{name}: {value.__module__}.{value.__class__.__name__} object')
            else:
                lines.append(f'{name}: {value!r}')
        header = f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n'
        return header + '\n'.join(lines)

    def __repr__(self):
        """Return the same text as __str__()."""
        return self.__str__()

    def __getattr__(self, attr):
        """Raise an AttributeError that names the class and appends the class docstring as a hint."""
        cls_name = self.__class__.__name__
        message = f"'{cls_name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}"
        raise AttributeError(message)
|
||
|
|
||
|
|
||
|
class IterableSimpleNamespace(SimpleNamespace):
|
||
|
"""Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
|
||
|
enables usage with dict() and for loops.
|
||
|
"""
|
||
|
|
||
|
def __iter__(self):
|
||
|
"""Return an iterator of key-value pairs from the namespace's attributes."""
|
||
|
return iter(vars(self).items())
|
||
|
|
||
|
def __str__(self):
|
||
|
"""Return a human-readable string representation of the object."""
|
||
|
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
|
||
|
|
||
|
def __getattr__(self, attr):
|
||
|
"""Custom attribute access error message with helpful information."""
|
||
|
name = self.__class__.__name__
|
||
|
raise AttributeError(f"""
|
||
|
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
|
||
|
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
|
||
|
{DEFAULT_CFG_PATH} with the latest version from
|
||
|
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
|
||
|
""")
|
||
|
|
||
|
def get(self, key, default=None):
|
||
|
"""Return the value of the specified key if it exists; otherwise, return the default value."""
|
||
|
return getattr(self, key, default)
|
||
|
|
||
|
|
||
|
def plt_settings(rcparams=None, backend='Agg'):
    """
    Decorator factory that temporarily sets matplotlib rc parameters and the backend for a plotting function.

    Example:
        @plt_settings({"font.size": 12})
        def plot(): ...

    Args:
        rcparams (dict, optional): Dictionary of rc parameters to set. Defaults to {'font.size': 11}.
        backend (str, optional): Name of the backend to use. Defaults to 'Agg'.

    Returns:
        (Callable): A decorator. The decorated function runs with the requested rc parameters and backend, and the
            previous backend is restored afterwards, even if the function raises.
    """
    from functools import wraps

    if rcparams is None:
        rcparams = {'font.size': 11}

    def decorator(func):
        """Wrap func so that it executes under the temporary rc parameters and backend."""

        @wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            """Set rc parameters and backend, call the original function, and always restore the backend."""
            original_backend = plt.get_backend()
            # Backend names are case-insensitive, so compare lowercased to avoid a pointless switch ('agg' vs 'Agg')
            switch = backend.lower() != original_backend.lower()
            if switch:
                plt.close('all')  # auto-close()ing of figures upon backend switching is deprecated since 3.8
                plt.switch_backend(backend)

            try:
                with plt.rc_context(rcparams):
                    return func(*args, **kwargs)
            finally:
                # Restore the caller's backend even when func raised
                if switch:
                    plt.close('all')
                    plt.switch_backend(original_backend)

        return wrapper

    return decorator
|
||
|
|
||
|
|
||
|
def set_logging(name=LOGGING_NAME, verbose=True):
    """
    Create and return a stdout logger for the given name, attempting UTF-8 console output on Windows.

    Args:
        name (str): Name of the logger to configure.
        verbose (bool): When True and the process rank is -1 or 0 the level is INFO, otherwise ERROR.

    Returns:
        (logging.Logger): The configured logger, with propagation to parent loggers disabled.
    """
    # Rank in world for Multi-GPU trainings: only rank -1 or 0 logs at INFO
    is_main_rank = RANK in {-1, 0}
    level = logging.INFO if verbose and is_main_rank else logging.ERROR

    # Configure the console (stdout) encoding to UTF-8
    message_format = '%(message)s'
    formatter = logging.Formatter(message_format)  # default formatter
    if WINDOWS and sys.stdout.encoding != 'utf-8':
        try:
            if hasattr(sys.stdout, 'reconfigure'):
                sys.stdout.reconfigure(encoding='utf-8')
            elif hasattr(sys.stdout, 'buffer'):
                import io

                sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
            else:
                sys.stdout.encoding = 'utf-8'
        except Exception as err:
            print(f'Creating custom formatter for non UTF-8 environments due to {err}')

            class CustomFormatter(logging.Formatter):
                """Formatter that passes every formatted record through emojis()."""

                def format(self, record):
                    """Format the record, then sanitize the result with emojis()."""
                    return emojis(super().format(record))

            formatter = CustomFormatter(message_format)  # last recourse when UTF-8 output cannot be enabled

    # Create and configure the StreamHandler
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    handler.setLevel(level)

    configured = logging.getLogger(name)
    configured.setLevel(level)
    configured.addHandler(handler)
    configured.propagate = False
    return configured
|
||
|
|
||
|
|
||
|
# Set logger
# Defined globally (used in train.py, val.py, predict.py, etc.)
LOGGER = set_logging(name=LOGGING_NAME, verbose=VERBOSE)
# A level above CRITICAL silences every record from these third-party loggers
for logger in ('sentry_sdk', 'urllib3.connectionpool'):
    logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
|
||
|
|
||
|
|
||
|
def emojis(string=''):
    """Return the string unchanged, except on Windows where all non-ASCII characters are dropped."""
    if not WINDOWS:
        return string
    # Encode to UTF-8 bytes, then decode as ASCII while ignoring every byte that is not ASCII
    return string.encode().decode('ascii', 'ignore')
|
||
|
|
||
|
|
||
|
class ThreadingLocked:
    """
    Decorator class that serializes calls to the decorated function or method across threads.

    Every call of the decorated function first acquires the decorator's lock, so only one thread at a time executes
    the function body.

    Attributes:
        lock (threading.Lock): A lock object used to manage access to the decorated function.

    Example:
        ```python
        from ultralytics.utils import ThreadingLocked

        @ThreadingLocked()
        def my_function():
            # Your code here
            pass
        ```
    """

    def __init__(self):
        """Create the lock that guards the decorated function."""
        self.lock = threading.Lock()

    def __call__(self, f):
        """Return a wrapper around f that holds the lock for the duration of each call."""
        from functools import wraps

        @wraps(f)
        def decorated(*args, **kwargs):
            """Acquire the lock, run the wrapped callable, and release the lock even if it raises."""
            lock = self.lock
            lock.acquire()
            try:
                return f(*args, **kwargs)
            finally:
                lock.release()

        return decorated
|
||
|
|
||
|
|
||
|
def yaml_save(file='data.yaml', data=None, header=''):
    """
    Write a dictionary to a YAML file, creating the parent directory if needed.

    Note:
        Values that are not plain YAML-friendly types are replaced by their str() form directly in the passed
        dictionary, i.e. `data` is modified in place.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        data (dict): Data to save in YAML format. An empty dict is used when None.
        header (str, optional): Text written to the file before the YAML content.

    Returns:
        (None): Data is saved to the specified file.
    """
    if data is None:
        data = {}
    file = Path(file)
    parent = file.parent
    if not parent.exists():
        # Create parent directories if they don't exist
        parent.mkdir(parents=True, exist_ok=True)

    # Convert Path objects (and any other unsupported value type) to strings
    supported = (int, float, str, bool, list, tuple, dict, type(None))
    for key, value in data.items():
        if isinstance(value, supported):
            continue
        data[key] = str(value)

    # Dump data to file in YAML format, header first
    with open(file, 'w', errors='ignore', encoding='utf-8') as stream:
        if header:
            stream.write(header)
        yaml.safe_dump(data, stream, sort_keys=False, allow_unicode=True)
|
||
|
|
||
|
|
||
|
def yaml_load(file='data.yaml', append_filename=False):
    """
    Read a '.yaml' / '.yml' file and return its content as a dictionary.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the result under the 'yaml_file' key. Default is False.

    Returns:
        (dict): Parsed YAML data; an empty dict when the parser returns nothing for the file.
    """
    assert Path(file).suffix in ('.yaml', '.yml'), f'Attempting to load non-YAML file {file} with yaml_load()'
    with open(file, errors='ignore', encoding='utf-8') as stream:
        text = stream.read()

    # Remove special characters before parsing
    if not text.isprintable():
        text = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', text)

    # Always return a dict (yaml.safe_load() may return None for empty files)
    parsed = yaml.safe_load(text)
    data = parsed or {}
    if append_filename:
        data['yaml_file'] = str(file)
    return data
|
||
|
|
||
|
|
||
|
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
    """
    Log a YAML file, or an already loaded dictionary, as formatted YAML text.

    Args:
        yaml_file: The file path of the YAML file or a YAML-formatted dictionary.

    Returns:
        None
    """
    if isinstance(yaml_file, (str, Path)):
        yaml_dict = yaml_load(yaml_file)  # a path was given: load it first
    else:
        yaml_dict = yaml_file
    dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
    LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||
|
|
||
|
|
||
|
# Default configuration
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
for k in DEFAULT_CFG_DICT:
    v = DEFAULT_CFG_DICT[k]
    # The string 'none' (any casing) in the YAML stands for a real None
    if isinstance(v, str) and v.lower() == 'none':
        DEFAULT_CFG_DICT[k] = None
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||
|
|
||
|
|
||
|
def is_ubuntu() -> bool:
    """
    Check if the OS is Ubuntu by looking for 'ID=ubuntu' in /etc/os-release.

    Returns:
        (bool): True if OS is Ubuntu, False otherwise (including when the file does not exist).
    """
    try:
        with open('/etc/os-release') as release_file:
            return 'ID=ubuntu' in release_file.read()
    except FileNotFoundError:
        return False
|
||
|
|
||
|
|
||
|
def is_colab():
    """
    Check if the current script is running inside a Google Colab notebook.

    Returns:
        (bool): True if either Colab environment variable is set, False otherwise.
    """
    colab_markers = ('COLAB_RELEASE_TAG', 'COLAB_BACKEND_VERSION')
    return any(marker in os.environ for marker in colab_markers)
|
||
|
|
||
|
|
||
|
def is_kaggle():
    """
    Check if the current script is running inside a Kaggle kernel.

    Returns:
        (bool): True if both the PWD and KAGGLE_URL_BASE environment variables hold the Kaggle values.
    """
    in_kaggle_workdir = os.environ.get('PWD') == '/kaggle/working'
    has_kaggle_url = os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
    return in_kaggle_workdir and has_kaggle_url
|
||
|
|
||
|
|
||
|
def is_jupyter():
    """
    Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.

    Returns:
        (bool): True if IPython reports an active shell, False otherwise (including when IPython is unavailable).
    """
    try:
        from IPython import get_ipython

        return get_ipython() is not None
    except Exception:
        # Best-effort probe: any failure is treated as "not in Jupyter"
        return False
|
||
|
|
||
|
|
||
|
def is_docker() -> bool:
    """
    Determine if the script is running inside a Docker container by searching /proc/self/cgroup for 'docker'.

    Returns:
        (bool): True if the file exists and contains 'docker', False otherwise.
    """
    cgroup_file = Path('/proc/self/cgroup')
    if not cgroup_file.exists():
        return False
    with open(cgroup_file) as stream:
        return 'docker' in stream.read()
|
||
|
|
||
|
|
||
|
def is_online() -> bool:
    """
    Check internet connectivity by attempting a TCP connection to port 53 of well-known DNS hosts.

    Returns:
        (bool): True as soon as one connection succeeds, False if every host fails.
    """
    import socket

    dns_hosts = ('1.1.1.1', '8.8.8.8', '223.5.5.5')  # Cloudflare, Google, AliDNS
    for host in dns_hosts:
        try:
            connection = socket.create_connection(address=(host, 53), timeout=2)
        except (socket.timeout, socket.gaierror, OSError):
            continue  # unreachable host, try the next one
        # Close the successful connection to avoid a ResourceWarning
        connection.close()
        return True
    return False
|
||
|
|
||
|
|
||
|
ONLINE = is_online()
|
||
|
|
||
|
|
||
|
def is_pip_package(filepath: str = __name__) -> bool:
    """
    Determines if the given module name resolves to an importable module that has an origin.

    Args:
        filepath (str): The module name to look up. Defaults to this module's own name.

    Returns:
        (bool): True if a module spec is found and its origin is set, False otherwise.
    """
    import importlib.util

    # Get the spec for the module
    spec = importlib.util.find_spec(filepath)
    if spec is None:
        return False
    return spec.origin is not None
|
||
|
|
||
|
|
||
|
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Check if a directory is writeable for the current process.

    Args:
        dir_path (str | Path): The path to the directory.

    Returns:
        (bool): True if os.access() reports write permission for the path, False otherwise.
    """
    target = str(dir_path)
    return os.access(target, os.W_OK)
|
||
|
|
||
|
|
||
|
def is_pytest_running():
    """
    Determines whether pytest is currently running or not.

    Returns:
        (bool): True if PYTEST_CURRENT_TEST is set, the pytest module is loaded, or the launching script's name
            contains 'pytest'; False otherwise.
    """
    if 'PYTEST_CURRENT_TEST' in os.environ:
        return True
    if 'pytest' in sys.modules:
        return True
    return 'pytest' in Path(sys.argv[0]).stem
|
||
|
|
||
|
|
||
|
def is_github_action_running() -> bool:
    """
    Determine if the current environment is a GitHub Actions runner.

    Returns:
        (bool): True only if GITHUB_ACTIONS, GITHUB_WORKFLOW and RUNNER_OS are all set in the environment.
    """
    required_vars = ('GITHUB_ACTIONS', 'GITHUB_WORKFLOW', 'RUNNER_OS')
    return all(var in os.environ for var in required_vars)
|
||
|
|
||
|
|
||
|
def is_git_dir():
    """
    Determines whether the current file is part of a git repository.

    Returns:
        (bool): True if get_git_dir() finds a repository root for this file, False otherwise.
    """
    git_root = get_git_dir()
    return git_root is not None
|
||
|
|
||
|
|
||
|
def get_git_dir():
    """
    Find the root directory of the git repository that contains the current file.

    Walks up the parent directories of this file and returns the first one holding a '.git' directory.

    Returns:
        (Path | None): Git root directory if found or None if not found.
    """
    ancestors = Path(__file__).parents
    return next((folder for folder in ancestors if (folder / '.git').is_dir()), None)
|
||
|
|
||
|
|
||
|
def get_git_origin_url():
    """
    Retrieves the origin URL of a git repository.

    The git command is only run when is_git_dir() reports that this file lives in a repository.

    Returns:
        (str | None): The origin URL of the git repository, or None if not a git directory, if git exits with an
            error, or if the git executable is not installed.
    """
    if is_git_dir():
        # CalledProcessError: git exited non-zero; FileNotFoundError: the 'git' executable is missing
        with contextlib.suppress(subprocess.CalledProcessError, FileNotFoundError):
            origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
            return origin.decode().strip()
    return None
|
||
|
|
||
|
|
||
|
def get_git_branch():
    """
    Returns the current git branch name.

    The git command is only run when is_git_dir() reports that this file lives in a repository.

    Returns:
        (str | None): The current git branch name, or None if not a git directory, if git exits with an error, or
            if the git executable is not installed.
    """
    if is_git_dir():
        # CalledProcessError: git exited non-zero; FileNotFoundError: the 'git' executable is missing
        with contextlib.suppress(subprocess.CalledProcessError, FileNotFoundError):
            branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
            return branch.decode().strip()
    return None
|
||
|
|
||
|
|
||
|
def get_default_args(func):
    """
    Returns a dictionary of default arguments for a function.

    Args:
        func (callable): The function to inspect.

    Returns:
        (dict): Maps each parameter that declares a default to that default value; parameters without a default
            are omitted.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue  # parameter has no default
        defaults[name] = param.default
    return defaults
|
||
|
|
||
|
|
||
|
def get_ubuntu_version():
    """
    Retrieve the Ubuntu version if the OS is Ubuntu.

    Returns:
        (str | None): The 'major.minor' value of VERSION_ID in /etc/os-release, or None if the OS is not Ubuntu,
            the file is missing, or no matching VERSION_ID entry is present.
    """
    if is_ubuntu():
        with contextlib.suppress(FileNotFoundError):
            with open('/etc/os-release') as f:
                match = re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())
            # Subscripting a failed match (None) would raise TypeError, so test the match explicitly
            return match[1] if match else None
    return None
|
||
|
|
||
|
|
||
|
def get_user_config_dir(sub_dir='Ultralytics'):
    """
    Get the user config directory, creating it if it does not exist.

    Args:
        sub_dir (str): The name of the subdirectory to create.

    Returns:
        (Path): The path to the user config directory.

    Raises:
        ValueError: If the operating system is not Windows, macOS or Linux.
    """
    # Pick the conventional config location for each operating system
    if WINDOWS:
        path = Path.home() / 'AppData' / 'Roaming' / sub_dir
    elif MACOS:  # macOS
        path = Path.home() / 'Library' / 'Application Support' / sub_dir
    elif LINUX:
        path = Path.home() / '.config' / sub_dir
    else:
        raise ValueError(f'Unsupported operating system: {platform.system()}')

    # GCP and AWS lambda fix, only /tmp is writeable
    if not is_dir_writeable(path.parent):
        LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD. "
                       'Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.')
        path = Path('/tmp') / sub_dir if is_dir_writeable('/tmp') else Path().cwd() / sub_dir

    # Create the subdirectory if it does not exist
    path.mkdir(parents=True, exist_ok=True)

    return path
|
||
|
|
||
|
|
||
|
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR') or get_user_config_dir()) # Ultralytics settings dir
|
||
|
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
|
||
|
|
||
|
|
||
|
def colorstr(*input):
    r"""
    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.

    This function can be called in two ways:
        - colorstr('color', 'style', 'your string')
        - colorstr('your string')

    In the second form, 'blue' and 'bold' will be applied by default.

    Args:
        *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
            and the last string is the one to be colored.

    Supported Colors and Styles:
        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
        Misc: 'end', 'bold', 'underline'

    Returns:
        (str): The input string wrapped with ANSI escape codes for the specified color and style.

    Examples:
        >>> colorstr('blue', 'bold', 'hello world')
        >>> '\033[34m\033[1mhello world\033[0m'
    """
    if len(input) > 1:
        *styles, string = input  # color/style arguments, then the string
    else:
        styles, string = ['blue', 'bold'], input[0]

    base_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
    codes = {}
    for offset, color in enumerate(base_names):
        codes[color] = f'\033[{30 + offset}m'  # basic colors use codes 30-37
    for offset, color in enumerate(base_names):
        codes[f'bright_{color}'] = f'\033[{90 + offset}m'  # bright colors use codes 90-97
    codes['end'] = '\033[0m'  # misc
    codes['bold'] = '\033[1m'
    codes['underline'] = '\033[4m'

    prefix = ''.join(codes[style] for style in styles)
    return prefix + f'{string}' + codes['end']
|
||
|
|
||
|
|
||
|
def remove_colorstr(input_string):
    r"""
    Removes ANSI escape codes from a string, effectively un-coloring it.

    Args:
        input_string (str): The string to remove color and style from.

    Returns:
        (str): A new string with all ANSI escape codes removed.

    Examples:
        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
        >>> 'hello world'
    """
    # ESC '[' followed by digits/semicolons and one terminating letter
    return re.sub(r'\x1B\[[0-9;]*[A-Za-z]', '', input_string)
|
||
|
|
||
|
|
||
|
class TryExcept(contextlib.ContextDecorator):
    """
    YOLOv8 TryExcept class that swallows exceptions and optionally prints them.

    Use as @TryExcept() decorator or 'with TryExcept():' context manager.
    """

    def __init__(self, msg='', verbose=True):
        """Store the optional message prefix and whether caught exceptions are printed."""
        self.msg = msg
        self.verbose = verbose

    def __enter__(self):
        """Nothing to set up when entering the context."""
        pass

    def __exit__(self, exc_type, value, traceback):
        """Print the exception (if any and verbose) and return True so that it is suppressed."""
        if self.verbose and value:
            separator = ': ' if self.msg else ''
            print(emojis(f'{self.msg}{separator}{value}'))
        return True
|
||
|
|
||
|
|
||
|
def threaded(func):
    """
    Run the decorated function in a daemon thread on every call and return that thread.

    Use as @threaded decorator.
    """

    def wrapper(*args, **kwargs):
        """Start a daemon thread that executes func with the given arguments and return the thread object."""
        worker = threading.Thread(
            target=func,
            args=args,
            kwargs=kwargs,
            daemon=True,
        )
        worker.start()
        return worker

    return wrapper
|
||
|
|
||
|
|
||
|
def set_sentry():
    """
    Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
    sync=True in settings. Run 'yolo settings' to see and update settings YAML file.

    Conditions required to send errors (ALL conditions must be met or no errors will be reported):
        - sentry_sdk package is installed
        - sync=True in YOLO settings
        - pytest is not running
        - running in a pip package installation
        - running in a non-git directory
        - running with rank -1 or 0
        - online environment
        - CLI used to run package (checked with 'yolo' as the name of the main CLI command)

    The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
    exceptions and to exclude events with 'out of memory' in their exception message.

    Additionally, the function sets custom tags and user information for Sentry events.
    """

    def before_send(event, hint):
        """
        Modify the event before sending it to Sentry based on specific exception types and messages.

        Args:
            event (dict): The event dictionary containing information about the error.
            hint (dict): A dictionary containing additional information about the error.

        Returns:
            dict: The modified event or None if the event should not be sent to Sentry.
        """
        if 'exc_info' in hint:
            exc_type, exc_value, _ = hint['exc_info']
            ignored_type = exc_type in (KeyboardInterrupt, FileNotFoundError)
            if ignored_type or 'out of memory' in str(exc_value):
                return None  # do not send event

        # Classify the installation: git checkout first, then pip package, otherwise unknown
        if is_git_dir():
            install = 'git'
        elif is_pip_package():
            install = 'pip'
        else:
            install = 'other'

        event['tags'] = {
            'sys_argv': sys.argv[0],
            'sys_argv_name': Path(sys.argv[0]).name,
            'install': install,
            'os': ENVIRONMENT}
        return event

    # Every condition must hold; evaluation short-circuits in this order
    enabled = (SETTINGS['sync'] and RANK in (-1, 0) and Path(sys.argv[0]).name == 'yolo' and not TESTS_RUNNING
               and ONLINE and is_pip_package() and not is_git_dir())
    if not enabled:
        return

    # If sentry_sdk package is not installed then return and do not use Sentry
    try:
        import sentry_sdk  # noqa
    except ImportError:
        return

    sentry_sdk.init(
        dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
        debug=False,
        traces_sample_rate=1.0,
        release=__version__,
        environment='production',  # 'dev' or 'production'
        before_send=before_send,
        ignore_errors=[KeyboardInterrupt, FileNotFoundError])
    sentry_sdk.set_user({'id': SETTINGS['uuid']})  # SHA-256 anonymized UUID hash
|
||
|
|
||
|
|
||
|
class SettingsManager(dict):
    """
    Manages Ultralytics settings stored in a YAML file.

    The object is a dict of the current settings; `update()` and `reset()` write the file after every change.

    Args:
        file (str | Path): Path to the Ultralytics settings YAML file. Default is USER_CONFIG_DIR / 'settings.yaml'.
        version (str): Settings version. In case of local version mismatch, new default settings will be saved.
    """

    def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
        """Build the default settings, load the YAML file (creating it if missing) and reset it when invalid."""
        import copy
        import hashlib

        from ultralytics.utils.checks import check_version
        from ultralytics.utils.torch_utils import torch_distributed_zero_first

        # Inside a git checkout the datasets live next to the repo (when writeable), otherwise relative to CWD
        git_dir = get_git_dir()
        root = git_dir or Path()
        datasets_root = (root.parent if git_dir and is_dir_writeable(root.parent) else root).resolve()

        self.file = Path(file)
        self.version = version
        self.defaults = {
            'settings_version': version,
            'datasets_dir': str(datasets_root / 'datasets'),
            'weights_dir': str(root / 'weights'),
            'runs_dir': str(root / 'runs'),
            'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
            'sync': True,
            'api_key': '',
            'clearml': True,  # integrations
            'comet': True,
            'dvc': True,
            'hub': True,
            'mlflow': True,
            'neptune': True,
            'raytune': True,
            'tensorboard': True,
            'wandb': True}

        super().__init__(copy.deepcopy(self.defaults))

        with torch_distributed_zero_first(RANK):
            if not self.file.exists():
                self.save()

            self.load()
            # Loaded settings are valid only if keys, value types and version all match the defaults
            correct_keys = self.keys() == self.defaults.keys()
            correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))
            correct_version = check_version(self['settings_version'], self.version)
            if not (correct_keys and correct_types and correct_version):
                LOGGER.warning(
                    'WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem '
                    'with your settings or a recent ultralytics package update. '
                    f"\nView settings with 'yolo settings' or at '{self.file}'"
                    "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'.")
                self.reset()

    def load(self):
        """Loads settings from the YAML file into this dict without saving."""
        super().update(yaml_load(self.file))

    def save(self):
        """Saves the current settings to the YAML file."""
        yaml_save(self.file, dict(self))

    def update(self, *args, **kwargs):
        """Updates setting values in the current settings and saves them to the YAML file."""
        super().update(*args, **kwargs)
        self.save()

    def reset(self):
        """Resets the settings to default and saves them."""
        self.clear()
        self.update(self.defaults)  # update() already persists, so no second save() is needed
|
||
|
|
||
|
|
||
|
def deprecation_warn(arg, new_arg, version=None):
    """
    Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.

    Args:
        arg (str): Name of the deprecated argument.
        new_arg (str): Name of the argument to use instead.
        version (optional): Release in which the argument is removed. When omitted it is derived from the installed
            package version as 'major.(minor + 2)'.
    """
    if not version:
        # Integer arithmetic on major.minor: float maths gives 8.1 + 0.2 -> 8.299999999999999 and slicing the
        # first three characters mis-reads two-digit minors such as '8.10'
        match = re.match(r'(\d+)\.(\d+)', __version__)
        version = f'{match[1]}.{int(match[2]) + 2}' if match else __version__
    LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
                   f"Please use '{new_arg}' instead.")
|
||
|
|
||
|
|
||
|
def clean_url(url):
    """
    Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.

    Args:
        url (str): The URL to clean.

    Returns:
        (str): The percent-decoded URL with everything from the first '?' removed.
    """
    # Import the submodule explicitly: a bare 'import urllib' does not guarantee that urllib.parse is loaded
    from urllib.parse import unquote

    url = Path(url).as_posix().replace(':/', '://')  # Pathlib turns :// -> :/, as_posix() for Windows
    return unquote(url).split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
|
||
|
|
||
|
|
||
|
def url2file(url):
    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
    cleaned = clean_url(url)  # auth/query removed and percent-escapes decoded
    return Path(cleaned).name
|
||
|
|
||
|
|
||
|
# Run below code on utils init ------------------------------------------------------------------------------------

# Check first-install steps
PREFIX = colorstr('Ultralytics: ')
SETTINGS = SettingsManager()  # initialize settings
DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory
WEIGHTS_DIR = Path(SETTINGS['weights_dir'])  # global weights directory
RUNS_DIR = Path(SETTINGS['runs_dir'])  # global runs directory
# Hosted environments are checked in this order; otherwise fall back to the OS name
if is_colab():
    ENVIRONMENT = 'Colab'
elif is_kaggle():
    ENVIRONMENT = 'Kaggle'
elif is_jupyter():
    ENVIRONMENT = 'Jupyter'
elif is_docker():
    ENVIRONMENT = 'Docker'
else:
    ENVIRONMENT = platform.system()
TESTS_RUNNING = is_pytest_running() or is_github_action_running()
set_sentry()

# Apply monkey patches
from .patches import imread, imshow, imwrite, torch_save

torch.save = torch_save
if WINDOWS:
    # Apply cv2 patches for non-ASCII and non-UTF characters in image paths
    cv2.imread = imread
    cv2.imwrite = imwrite
    cv2.imshow = imshow
|