A vision library for performing sliced inference on large images/small objects
npx @tessl/cli install tessl/pypi-sahi@0.11.0

A comprehensive computer vision library designed for object detection and instance segmentation on high-resolution images. SAHI addresses the challenge of detecting small objects in large images through sliced inference: the large image is divided into smaller, overlapping patches, inference runs on each patch, and the per-patch predictions are merged back into full-image coordinates.
pip install sahi

import sahi
from sahi import AutoDetectionModel

Core classes and functions:
from sahi import (
    BoundingBox,
    Category,
    Mask,
    AutoDetectionModel,
    DetectionModel,
    ObjectPrediction,
)

Prediction functions:
from sahi.predict import get_prediction, get_sliced_prediction, predict

Basic usage:

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# Load a detection model
detection_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='yolov8n.pt',
    confidence_threshold=0.3,
    device="cpu"
)
# Perform sliced inference on a large image
result = get_sliced_prediction(
image="path/to/large_image.jpg",
detection_model=detection_model,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2
)
# Access predictions
predictions = result.object_prediction_list
for prediction in predictions:
print(f"Class: {prediction.category.name}")
print(f"Confidence: {prediction.score.value}")
print(f"BBox: {prediction.bbox.to_coco_bbox()}")
# Export visualization
result.export_visuals(export_dir="output/")

SAHI's architecture centers on three key concepts: a unified detection model interface, sliced inference with configurable patch size and overlap, and postprocessing that merges overlapping predictions.
The library integrates with popular detection frameworks and provides consistent APIs for slicing, prediction postprocessing, dataset operations, and visualization in both research and production settings.
Unified interface for loading detection models from various frameworks including Ultralytics YOLO, MMDetection, Detectron2, HuggingFace Transformers, TorchVision, and Roboflow.
class AutoDetectionModel:
    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
        **kwargs,
    ) -> DetectionModel: ...
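The same entry point loads models from any supported framework by switching model_type; a minimal sketch, assuming a HuggingFace DETR checkpoint (the model name here is illustrative):

detection_model = AutoDetectionModel.from_pretrained(
    model_type='huggingface',
    model_path='facebook/detr-resnet-50',  # illustrative checkpoint
    confidence_threshold=0.4,
    device="cuda:0",
)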
Main prediction capabilities including standard inference, sliced inference for large images, batch processing, and video processing, with comprehensive parameter control.

def get_prediction(
    image,
    detection_model,
    shift_amount: list = [0, 0],
    full_shape=None,
    postprocess: Optional[PostprocessPredictions] = None,
    verbose: int = 0,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult: ...
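Standard full-image inference is a one-liner once a model is loaded; a short sketch reusing the detection_model from the quick start:

from sahi.predict import get_prediction

result = get_prediction("path/to/image.jpg", detection_model)
print(f"{len(result.object_prediction_list)} objects detected")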
def get_sliced_prediction(
    image,
    detection_model,
    slice_height: Optional[int] = None,
    slice_width: Optional[int] = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: Optional[int] = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: Optional[str] = None,
    slice_dir: Optional[str] = None,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult: ...
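The postprocess_* parameters control how detections from overlapping slices are merged; a sketch switching from the default greedy NMM to plain NMS matched on IOU:

result = get_sliced_prediction(
    image="path/to/large_image.jpg",
    detection_model=detection_model,
    slice_height=640,
    slice_width=640,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
)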
def predict(
    detection_model: Optional[DetectionModel] = None,
    model_type: str = "ultralytics",
    model_path: Optional[str] = None,
    model_config_path: Optional[str] = None,
    model_confidence_threshold: float = 0.25,
    model_device: Optional[str] = None,
    model_category_mapping: Optional[dict] = None,
    model_category_remapping: Optional[dict] = None,
    source: Optional[str] = None,
    no_standard_prediction: bool = False,
    no_sliced_prediction: bool = False,
    image_size: Optional[int] = None,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    novisual: bool = False,
    view_video: bool = False,
    frame_skip_interval: int = 0,
    export_pickle: bool = False,
    export_crop: bool = False,
    dataset_json_path: Optional[str] = None,
    project: str = "runs/predict",
    name: str = "exp",
    visual_bbox_thickness: Optional[int] = None,
    visual_text_size: Optional[float] = None,
    visual_text_thickness: Optional[int] = None,
    visual_hide_labels: bool = False,
    visual_hide_conf: bool = False,
    visual_export_format: str = "png",
    verbose: int = 1,
    return_dict: bool = False,
    force_postprocess_type: bool = False,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
    **kwargs,
) -> Optional[Dict]: ...
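predict wraps model loading, slicing, inference, and export for a whole source in one call; a minimal sketch over a folder of images (the folder path is illustrative):

from sahi.predict import predict

predict(
    model_type="ultralytics",
    model_path="yolov8n.pt",
    source="path/to/image_folder/",  # image file, folder, or video
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
# Outputs land under the default project/name: runs/predict/exp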
Core data structures for handling bounding boxes, masks, categories, and complete object annotations, with comprehensive format conversion and manipulation methods.

@dataclass(frozen=True)
class BoundingBox:
    box: Union[Tuple[float, float, float, float], List[float]]
    shift_amount: Tuple[int, int] = (0, 0)

    def get_expanded_box(self, ratio: float = 0.1, max_x: int = None, max_y: int = None) -> "BoundingBox": ...
    def to_coco_bbox(self) -> List[float]: ...
    def to_xyxy(self) -> List[float]: ...
    def get_shifted_box(self) -> "BoundingBox": ...
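Boxes are stored in [minx, miny, maxx, maxy] order, and the converters reshape them for other tools; a small sketch with illustrative values:

box = BoundingBox([10.0, 20.0, 50.0, 80.0])
print(box.to_xyxy())       # [10.0, 20.0, 50.0, 80.0]
print(box.to_coco_bbox())  # [10.0, 20.0, 40.0, 60.0] as [x, y, width, height]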
@dataclass(frozen=True)
class Category:
    id: Optional[Union[int, str]] = None
    name: Optional[str] = None
class Mask:
    def __init__(self, bool_mask: Optional[np.ndarray] = None, segmentation: Optional[List] = None, shift_amount: Tuple[int, int] = (0, 0)): ...

    @classmethod
    def from_float_mask(cls, mask: np.ndarray, mask_threshold: float = 0.5, shift_amount: Tuple[int, int] = (0, 0)) -> "Mask": ...

    @classmethod
    def from_bool_mask(cls, mask: np.ndarray, shift_amount: Tuple[int, int] = (0, 0)) -> "Mask": ...

    def get_shifted_mask(self) -> "Mask": ...
class ObjectPrediction(ObjectAnnotation):
    def __init__(
        self,
        bbox: Optional[BoundingBox] = None,
        category: Optional[Category] = None,
        score: Optional[PredictionScore] = None,
        mask: Optional[Mask] = None,
        shift_amount: Optional[List[int]] = None,
        full_shape: Optional[List[int]] = None,
    ): ...

    def get_shifted_object_prediction(self) -> "ObjectPrediction": ...
    def to_coco_prediction(self) -> CocoPrediction: ...
    def to_fiftyone_detection(self): ...
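Predictions can also be constructed by hand, e.g. to feed another model's raw output into SAHI's postprocessing; a minimal sketch using the structures above, with illustrative values (PredictionScore is documented further below, and the sahi.prediction import path is an assumption):

from sahi.prediction import PredictionScore  # assumed module path

prediction = ObjectPrediction(
    bbox=BoundingBox([10.0, 20.0, 50.0, 80.0]),
    category=Category(id=2, name="car"),
    score=PredictionScore(0.87),
    full_shape=[1080, 1920],  # original image height, width
)
coco_prediction = prediction.to_coco_prediction()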
Advanced image slicing capabilities for handling large images, including automatic parameter calculation, annotation processing, and dataset slicing operations.

def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: Optional[int] = None,
    slice_width: Optional[int] = None,
    auto_slice_resolution: Optional[bool] = True,
    overlap_height_ratio: Optional[float] = 0.2,
    overlap_width_ratio: Optional[float] = 0.2,
) -> List[List[int]]: ...
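Slice windows can be computed without touching pixel data; a short sketch for a 1920x1080 image, assuming the slicing helpers live in sahi.slicing:

from sahi.slicing import get_slice_bboxes

slice_bboxes = get_slice_bboxes(
    image_height=1080,
    image_width=1920,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
# Each entry is an [xmin, ymin, xmax, ymax] window into the original image
print(len(slice_bboxes), slice_bboxes[0])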
def slice_image(
    image: Union[str, Image.Image],
    output_file_name: Optional[str] = None,
    output_dir: Optional[str] = None,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    auto_slice_resolution: bool = True,
    min_area_ratio: float = 0.1,
    out_ext: Optional[str] = None,
    verbose: bool = False,
) -> SliceImageResult: ...
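slice_image computes the same windows and also writes the patches to disk; a minimal sketch tiling one image (paths are illustrative):

from sahi.slicing import slice_image

slice_result = slice_image(
    image="path/to/large_image.jpg",
    output_file_name="large_image",
    output_dir="slices/",
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
# slice_result is a SliceImageResult describing the generated patches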
def slice_coco(
    coco_annotation_file_path: str,
    image_dir: str,
    output_coco_annotation_file_name: str = "",
    output_dir: Optional[str] = None,
    ignore_negative_samples: bool = False,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    min_area_ratio: float = 0.1,
    verbose: bool = False,
) -> str: ...

Advanced postprocessing methods for combining overlapping predictions, including Non-Maximum Suppression (NMS), Non-Maximum Merging (NMM), and specialized algorithms for sliced inference results.
class PostprocessPredictions:
    def __init__(
        self,
        match_threshold: float = 0.5,
        match_metric: str = "IOS",
        class_agnostic: bool = False,
    ): ...

    def __call__(
        self,
        object_predictions: List[ObjectPrediction],
    ) -> List[ObjectPrediction]: ...
class NMSPostprocess(PostprocessPredictions): ...
class NMMPostprocess(PostprocessPredictions): ...
class GreedyNMMPostprocess(PostprocessPredictions): ...
class LSNMSPostprocess(PostprocessPredictions): ...
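Each postprocessor is a callable over a prediction list, so merging can also be run manually; a sketch applying greedy NMM directly, reusing result from the quick start and assuming the classes live in sahi.postprocess.combine (get_sliced_prediction applies this merging internally):

from sahi.postprocess.combine import GreedyNMMPostprocess  # assumed module path

postprocess = GreedyNMMPostprocess(
    match_threshold=0.5,
    match_metric="IOS",
    class_agnostic=False,
)
merged_predictions = postprocess(result.object_prediction_list)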
def nms(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]: ...
def greedy_nmm(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]: ...

Comprehensive COCO dataset handling including loading, manipulation, annotation processing, evaluation, and format conversion capabilities.
class Coco:
    def __init__(self, coco_path: Optional[str] = None): ...
    def add_image(self, coco_image: CocoImage) -> int: ...
    def add_annotation(self, coco_annotation: CocoAnnotation) -> int: ...
    def add_category(self, coco_category: CocoCategory) -> int: ...
    def merge(self, coco2: "Coco") -> "Coco": ...

    def export_as_yolo(
        self,
        output_dir: str,
        train_split_rate: float = 1.0,
        numpy_seed: int = 0,
        mp: bool = True,
    ): ...
class CocoImage:
    def __init__(self, image_path: str, image_id: Optional[int] = None): ...
    def add_annotation(self, annotation: CocoAnnotation): ...
class CocoAnnotation:
    def __init__(
        self,
        bbox: Optional[List[int]] = None,
        category_id: Optional[int] = None,
        category_name: Optional[str] = None,
        iscrowd: int = 0,
        area: Optional[int] = None,
        segmentation: Optional[List] = None,
        image_id: Optional[int] = None,
        annotation_id: Optional[int] = None,
    ): ...
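Putting these classes together, a dataset can be assembled in memory and exported to YOLO format; a minimal sketch using the signatures above (paths and labels are illustrative, and CocoCategory is referenced by add_category above but not documented here):

coco = Coco()
coco.add_category(CocoCategory(id=0, name="car"))  # CocoCategory assumed per add_category above
image = CocoImage(image_path="images/img_001.jpg", image_id=1)
image.add_annotation(
    CocoAnnotation(bbox=[50, 60, 120, 80], category_id=0, category_name="car")
)
coco.add_image(image)
coco.export_as_yolo(output_dir="yolo_dataset/", train_split_rate=0.8)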
def create_coco_dict() -> Dict: ...
def export_coco_as_yolo(
    coco_path: str,
    output_dir: str,
    train_split_rate: float = 1.0,
    numpy_seed: int = 0,
) -> str: ...

Complete command-line interface for prediction, dataset processing, evaluation, and format conversion operations, accessible through the sahi command.
# Main prediction command
sahi predict --model_type ultralytics --model_path yolov8n.pt --source image.jpg
# Prediction with FiftyOne integration
sahi predict-fiftyone --model_type ultralytics --model_path yolov8n.pt --source image.jpg
# COCO dataset operations
sahi coco slice --image_dir images/ --dataset_json_path dataset.json
sahi coco evaluate --dataset_json_path dataset.json --result_json_path results.json
sahi coco yolo --coco_annotation_file_path dataset.json --image_dir images/
sahi coco analyse --dataset_json_path dataset.json --result_json_path results.json
sahi coco fiftyone --coco_annotation_file_path dataset.json --image_dir images/
# Environment and version info
sahi version
sahi env

Utility functions for computer vision operations, framework-specific integrations, file I/O operations, and compatibility across different deep learning ecosystems.
# CV utilities
def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = True) -> Image.Image: ...
def read_image(image_path: str) -> np.ndarray: ...
def visualize_object_predictions(
    image: np.ndarray,
    object_prediction_list: List[ObjectPrediction],
    rect_th: int = 3,
    text_size: float = 3,
    text_th: float = 3,
    color: tuple = None,
    hide_labels: bool = False,
    hide_conf: bool = False,
    output_dir: Optional[str] = None,
    file_name: Optional[str] = "prediction_visual",
) -> np.ndarray: ...
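These helpers work on raw numpy images, so they are usable outside PredictionResult.export_visuals; a sketch drawing predictions onto an image and saving the overlay, reusing result from the quick start and assuming the helpers live in sahi.utils.cv (file names are illustrative):

from sahi.utils.cv import read_image, visualize_object_predictions

image = read_image("path/to/large_image.jpg")
visualize_object_predictions(
    image=image,
    object_prediction_list=result.object_prediction_list,
    output_dir="output/",
    file_name="annotated",
)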
def crop_object_predictions(
    image: np.ndarray,
    object_prediction_list: List[ObjectPrediction],
    output_dir: str,
    file_name: str,
    export_format: str = "PNG",
) -> None: ...
def get_video_reader(video_path: str): ...
# File utilities
def save_json(data, save_path: str, indent: Optional[int] = None): ...
def load_json(load_path: str, encoding: str = "utf-8") -> Dict: ...
def save_pickle(data: Any, save_path: str): ...
def load_pickle(load_path: str) -> Any: ...
def list_files(
    directory: str,
    contains: List[str] = None,
    verbose: bool = True,
    max_depth: Optional[int] = None,
) -> List[str]: ...
def get_base_filename(path: str) -> str: ...
def get_file_extension(path: str) -> str: ...
def download_from_url(from_url: str, to_path: str): ...
# Import utilities
def is_available(package: str) -> bool: ...
def check_requirements(requirements: List[str], raise_exception: bool = True): ...

Result container classes:

class PredictionResult:
    def __init__(
        self,
        object_prediction_list: List[ObjectPrediction],
        image: Image.Image,
        durations_in_seconds: Optional[Dict] = None,
    ): ...

    def export_visuals(self, export_dir: str, text_size: float = None): ...
    def to_coco_annotations(self) -> List[CocoAnnotation]: ...
    def to_coco_predictions(self) -> List[CocoPrediction]: ...
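The containers compose with each other; a short sketch converting a result to COCO annotations and filtering by score, using only the methods shown here and reusing result from the quick start:

coco_annotations = result.to_coco_annotations()
print(f"{len(coco_annotations)} COCO annotations")

# Keep only detections above a confidence threshold
confident_predictions = [
    p for p in result.object_prediction_list
    if p.score.is_greater_than_threshold(0.5)
]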
class PredictionScore:
    def __init__(self, value: Union[float, np.ndarray]): ...
    def is_greater_than_threshold(self, threshold: float) -> bool: ...
class SliceImageResult:
    def __init__(self, original_image_size: List[int], image_dir: str): ...
class SlicedImage:
    def __init__(self, image: Image.Image, coco_image: CocoImage, starting_pixel: List[int]): ...