84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
"""
|
|
YOLO-NAS model interface.
|
|
|
|
Example:
|
|
```python
|
|
from ultralytics import NAS
|
|
|
|
model = NAS('yolo_nas_s')
|
|
results = model.predict('ultralytics/assets/bus.jpg')
|
|
```
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from ultralytics.engine.model import Model
|
|
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
|
|
|
from .predict import NASPredictor
|
|
from .val import NASValidator
|
|
|
|
|
|
class NAS(Model):
    """
    YOLO NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        results = model.predict('ultralytics/assets/bus.jpg')
        ```

    Attributes:
        model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.

    Note:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.

        Args:
            model (str): Checkpoint path or model name. Must not have a '.yaml' or '.yml' suffix.

        Raises:
            AssertionError: If `model` has a '.yaml' or '.yml' suffix.
        """
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()
    def _load(self, weights: str, task: str):
        """
        Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.

        Args:
            weights (str): Either a path ending in '.pt', which is read with `torch.load`, or a name without any
                file suffix, which is passed to `super_gradients.training.models.get` with COCO pretrained weights.
            task (str): Not used by this method; the loaded model's task is always set to 'detect'.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' or the empty string.
        """
        import super_gradients  # imported lazily so the dependency is only needed when a NAS model is loaded

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # NOTE: torch.load unpickles the file and can therefore execute arbitrary code.
            # Only load checkpoints that come from a trusted source.
            self.model = torch.load(weights)
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Without this branch no model is assigned here and the attribute assignments below fail (or patch a
            # stale object) with an error that does not mention the real cause.
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}'. Expected a '.pt' checkpoint path "
                "or a super-gradients model name without a file suffix, i.e. 'yolo_nas_s'.")

        # Standardize model
        self.model.fuse = lambda verbose=True: self.model  # fusing is a no-op that hands back the model
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))  # class index -> class name
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            The value returned by `model_info` for the loaded model, evaluated with `imgsz=640`.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|