477 lines
23 KiB
Python
477 lines
23 KiB
Python
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
"""
|
|
Generate predictions using the Segment Anything Model (SAM).
|
|
|
|
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
|
|
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
|
|
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
|
|
segmentation tasks.
|
|
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from ultralytics.data.augment import LetterBox
|
|
from ultralytics.engine.predictor import BasePredictor
|
|
from ultralytics.engine.results import Results
|
|
from ultralytics.utils import DEFAULT_CFG, ops
|
|
from ultralytics.utils.torch_utils import select_device
|
|
|
|
from .amg import (
|
|
batch_iterator,
|
|
batched_mask_to_box,
|
|
build_all_layer_point_grids,
|
|
calculate_stability_score,
|
|
generate_crop_boxes,
|
|
is_box_near_crop_edge,
|
|
remove_small_regions,
|
|
uncrop_boxes_xyxy,
|
|
uncrop_masks,
|
|
)
|
|
from .build import build_sam
|
|
|
|
|
|
class Predictor(BasePredictor):
    """
    Predictor class for the Segment Anything Model (SAM), extending BasePredictor.

    The class provides an interface for model inference tailored to image segmentation tasks.
    With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
    mask generation. The class is capable of working with various types of prompts such as bounding boxes,
    points, and low-resolution masks.

    Attributes:
        cfg (dict): Configuration dictionary specifying model and task-related parameters.
        overrides (dict): Dictionary containing values that override the default configuration.
        _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
        args (namespace): Namespace to hold command-line arguments or other operational variables.
        im (torch.Tensor): Preprocessed input image tensor cached by `set_image`, or None.
        features (torch.Tensor): Image-encoder features cached by `set_image`, or None.
        prompts (dict): Collection of various prompt types, such as bounding boxes and points.
        segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the Predictor with configuration, overrides, and callbacks.

        The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
        forces the SAM-specific settings `task="segment"`, `mode="predict"` and `imgsz=1024`, and enables
        `retina_masks`.

        Args:
            cfg (dict): Configuration dictionary.
            overrides (dict, optional): Dictionary of values to override default configuration. It is not modified.
            _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
        """
        # Build a fresh dict instead of calling `overrides.update(...)` so the caller's dict is left untouched.
        overrides = {**(overrides or {}), "task": "segment", "mode": "predict", "imgsz": 1024}
        super().__init__(cfg, overrides, _callbacks)
        self.args.retina_masks = True
        self.im = None
        self.features = None
        self.prompts = {}
        self.segment_all = False

    def preprocess(self, im):
        """
        Preprocess the input image for model inference.

        If an image has already been cached (`self.im` is set), that cached tensor is returned and `im` is
        ignored. Otherwise a list of numpy images is letterboxed, channel-reversed, converted to a BCHW tensor and
        normalized with `self.mean` / `self.std`; a tensor input is only moved to the device and cast.

        Args:
            im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.

        Returns:
            (torch.Tensor): The preprocessed image tensor.
        """
        if self.im is not None:
            return self.im
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            # Reverse the channel axis (presumably BGR -> RGB for cv2-style input) and go from BHWC to BCHW.
            im = im[..., ::-1].transpose((0, 3, 1, 2))
            im = np.ascontiguousarray(im)
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()
        if not_tensor:
            # Normalization is only applied to inputs that arrived as numpy images.
            im = (im - self.mean) / self.std
        return im

    def pre_transform(self, im):
        """
        Perform initial transformations on the input image for preprocessing.

        The method letterboxes each image to `self.args.imgsz` (`auto=False`, `center=False`).
        Currently, batched inference is not supported; hence the list length should be 1.

        Args:
            im (List[np.ndarray]): List containing images in HWC numpy array format.

        Returns:
            (List[np.ndarray]): List of transformed images.

        Raises:
            AssertionError: If `im` does not contain exactly one image.
        """
        assert len(im) == 1, "SAM model does not currently support batched inference"
        letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
        return [letterbox(image=x) for x in im]

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """
        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
        mask decoder for real-time and promptable segmentation tasks.

        Prompts previously stored with `set_prompts` take precedence over the arguments and are consumed (popped)
        by this call. When no box, point or mask prompt is available the whole image is segmented via `generate`.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
            *args: Extra positional arguments forwarded to `generate` in segment-all mode.
            **kwargs: Extra keyword arguments forwarded to `generate` in segment-all mode.

        Returns:
            (tuple): `(masks, scores)` from `prompt_inference` when a prompt is given, otherwise
                `(masks, scores, bboxes)` from `generate`.
        """
        # Override prompts if any stored in self.prompts
        bboxes = self.prompts.pop("bboxes", bboxes)
        points = self.prompts.pop("points", points)
        masks = self.prompts.pop("masks", masks)
        # Stored point labels must travel with stored points; previously they were silently ignored.
        labels = self.prompts.pop("labels", labels)

        if all(i is None for i in [bboxes, points, masks]):
            return self.generate(im, *args, **kwargs)

        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
        """
        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.

        Point and box prompts are rescaled from source-image pixels to the network input size (no rescaling in
        segment-all mode), embedded by the prompt encoder and decoded into masks. The prompt arrays passed in by
        the caller are never modified.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.

        Returns:
            (tuple): Contains the following two elements.
                - torch.Tensor: Mask decoder output with the first two dimensions flattened, (N*d, H, W).
                - torch.Tensor: Predicted quality scores with the first two dimensions flattened, (N*d, ).
                `d` could be 1 or 3 depending on `multimask_output`.
        """
        # Reuse the features cached by `set_image` when available, otherwise encode the image now.
        features = self.model.image_encoder(im) if self.features is None else self.features

        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
        # Transform input prompts
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            # Assuming labels are all positive if users don't pass labels.
            if labels is None:
                labels = np.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            # Accept a scalar label for a single point, mirroring the 1-D handling of `points` above.
            labels = labels[None] if labels.ndim == 0 else labels
            # Out-of-place scaling: `torch.as_tensor` may share memory with the caller's array/tensor, so an
            # in-place `*=` would silently rescale the caller's data.
            points = points * r
            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
            points, labels = points[:, None, :], labels[:, None]
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bboxes = bboxes * r  # out-of-place for the same aliasing reason as `points`
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)

        points = (points, labels) if points is not None else None
        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)

        # Predict masks
        pred_masks, pred_scores = self.model.mask_decoder(
            image_embeddings=features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
        # `d` could be 1 or 3 depends on `multimask_output`.
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

    def generate(
        self,
        im,
        crop_n_layers=0,
        crop_overlap_ratio=512 / 1500,
        crop_downscale_factor=1,
        point_grids=None,
        points_stride=32,
        points_batch_size=64,
        conf_thres=0.88,
        stability_score_thresh=0.95,
        stability_score_offset=0.95,
        crop_nms_thresh=0.7,
    ):
        """
        Perform image segmentation using the Segment Anything Model (SAM).

        This function segments an entire image into constituent parts by prompting the model with a grid of points
        on the image and, optionally, on image crops. Masks are filtered by predicted quality and stability score,
        de-duplicated with NMS inside each crop, and de-duplicated again across crops when more than one crop is
        used. Calling this method switches the predictor into segment-all mode (`self.segment_all = True`).

        Args:
            im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
            crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
                Each layer produces 2**i_layer number of image crops.
            crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
            crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
            point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
                Used in the nth crop layer.
            points_stride (int, optional): Number of points to sample along each side of the image.
                Exclusive with 'point_grids'.
            points_batch_size (int): Batch size for the number of points processed simultaneously.
            conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
            stability_score_offset (float): Offset value for calculating stability score.
            crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.

        Returns:
            (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
        """
        import torchvision  # scope for faster 'import ultralytics'

        self.segment_all = True
        ih, iw = im.shape[2:]
        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
        if point_grids is None:
            point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
            x1, y1, x2, y2 = crop_region
            w, h = x2 - x1, y2 - y1
            area = torch.tensor(w * h, device=im.device)
            points_scale = np.array([[w, h]])  # w, h
            # Crop image and interpolate to input size
            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
            # (num_points, 2)
            points_for_image = point_grids[layer_idx] * points_scale
            crop_masks, crop_scores, crop_bboxes = [], [], []
            for (points,) in batch_iterator(points_batch_size, points_for_image):
                pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                # Interpolate predicted masks to input size
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                # Drop masks whose predicted quality is not above the confidence threshold.
                idx = pred_score > conf_thres
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                stability_score = calculate_stability_score(
                    pred_mask, self.model.mask_threshold, stability_score_offset
                )
                idx = stability_score > stability_score_thresh
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                # Bool type is much more memory-efficient.
                pred_mask = pred_mask > self.model.mask_threshold
                # (N, 4)
                pred_bbox = batched_mask_to_box(pred_mask).float()
                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                if not torch.all(keep_mask):
                    pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]

                crop_masks.append(pred_mask)
                crop_bboxes.append(pred_bbox)
                crop_scores.append(pred_score)

            # Do nms within this crop
            crop_masks = torch.cat(crop_masks)
            crop_bboxes = torch.cat(crop_bboxes)
            crop_scores = torch.cat(crop_scores)
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
            crop_scores = crop_scores[keep]

            pred_masks.append(crop_masks)
            pred_bboxes.append(crop_bboxes)
            pred_scores.append(crop_scores)
            # One area entry per kept mask, used below to rank masks across crops.
            region_areas.append(area.expand(len(crop_masks)))

        pred_masks = torch.cat(pred_masks)
        pred_bboxes = torch.cat(pred_bboxes)
        pred_scores = torch.cat(pred_scores)
        region_areas = torch.cat(region_areas)

        # Remove duplicate masks between crops
        if len(crop_regions) > 1:
            # Score by inverse crop area so NMS prefers masks coming from smaller crops.
            scores = 1 / region_areas
            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
            pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]

        return pred_masks, pred_scores, pred_bboxes

    def setup_model(self, model, verbose=True):
        """
        Initializes the Segment Anything Model (SAM) for inference.

        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
        parameters for image normalization and other Ultralytics compatibility settings.

        Args:
            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
            verbose (bool): If True, prints selected device information.

        Attributes:
            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
            device (torch.device): The device to which the model and tensors are allocated.
            mean (torch.Tensor): The mean values for image normalization.
            std (torch.Tensor): The standard deviation values for image normalization.
        """
        device = select_device(self.args.device, verbose=verbose)
        if model is None:
            model = build_sam(self.args.model)
        model.eval()
        self.model = model.to(device)
        self.device = device
        # Per-channel statistics shaped (3, 1, 1) so they broadcast over BCHW tensors in `preprocess`.
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)

        # Ultralytics compatibility settings
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.

        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
        Segment-all mode is reset to False before returning.

        Args:
            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
            img (torch.Tensor): The processed input image tensor.
            orig_imgs (list | torch.Tensor): The original, unprocessed images.

        Returns:
            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
        """
        # Masks and scores are always the first two elements; boxes are only present in segment-all mode.
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        # Every mask gets its own class id, named after its index.
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        # `[pred_masks]` has a single entry, so this loop body runs at most once (no batched inference).
        for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                # Box rows become (x1, y1, x2, y2, score, cls).
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            masks = masks > self.model.mask_threshold  # to bool
            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results

    def setup_source(self, source):
        """
        Sets up the data source for inference.

        This method configures the data source from which images will be fetched for inference. The source could be a
        directory, a video file, or other types of image data sources. A `None` source is ignored, leaving any
        previously configured source in place.

        Args:
            source (str | Path): The path to the image data source for inference.
        """
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """
        Preprocesses and sets a single image for inference.

        This function sets up the model if not already initialized, configures the data source to the specified image,
        and preprocesses the image for feature extraction. The preprocessed tensor and its encoder features are
        cached on the instance. Only one image can be set at a time.

        Args:
            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.

        Raises:
            AssertionError: If more than one image is set.
        """
        if self.model is None:
            model = build_sam(self.args.model)
            self.setup_model(model)
        self.setup_source(image)
        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            self.features = self.model.image_encoder(im)
            self.im = im
            break

    def set_prompts(self, prompts):
        """Set prompts in advance; they are consumed by the next call to `inference`."""
        self.prompts = prompts

    def reset_image(self):
        """Resets the image and its features to None."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
        Suppression (NMS) to eliminate any newly created duplicate boxes.

        Args:
            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
                the number of masks, H is height, and W is width.
            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.

        Returns:
            (tuple([torch.Tensor, torch.Tensor])):
                - new_masks (torch.Tensor): The processed masks that survived NMS, on the device and with the dtype
                    of `masks`.
                - keep (torch.Tensor): The indices of the remaining masks post-NMS, which can be used to filter the
                    boxes. Empty when `masks` is empty.
        """
        if len(masks) == 0:
            # Keep the (masks, keep) return contract for empty input so callers can always unpack the result.
            return masks, torch.empty(0, dtype=torch.long)

        import torchvision  # scope for faster 'import ultralytics'

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for mask in masks:
            mask = mask.cpu().numpy().astype(np.uint8)
            # Inside this function body the name resolves to the module-level helper, not to this static method.
            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)

        return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep