A vision library for performing sliced inference on large images/small objects
—
SAHI provides postprocessing methods for combining overlapping predictions from sliced inference. Depending on the algorithm, duplicate detections are suppressed (NMS), down-weighted (Linear Soft NMS), or merged (NMM, Greedy NMM). This eliminates duplicates and improves detection accuracy across slice boundaries.
Base class for all prediction postprocessing algorithms with configurable matching criteria.
class PostprocessPredictions:
    """
    Base class for prediction postprocessing algorithms.

    Subclasses combine overlapping predictions (for example, the duplicates
    produced where neighbouring slices overlap) using a configurable overlap
    metric, threshold and class-matching rule.
    """

    def __init__(
        self,
        match_threshold: float = 0.5,
        match_metric: str = "IOS",
        class_agnostic: bool = False,
    ):
        """
        Initialize postprocessing configuration.

        Parameters:
        - match_threshold (float): Overlap threshold for matching predictions (0-1)
        - match_metric (str): Overlap calculation method ("IOU" or "IOS")
            - "IOU": Intersection over Union
            - "IOS": Intersection over Smaller area
        - class_agnostic (bool): Whether to ignore class when matching predictions
        """

    def __call__(
        self,
        object_predictions: List[ObjectPrediction],
    ) -> List[ObjectPrediction]:
        """
        Apply postprocessing to prediction list.

        Parameters:
        - object_predictions: List of ObjectPrediction instances

        Returns:
            List[ObjectPrediction]: Processed predictions with duplicates removed
        """


# Classic NMS algorithm that removes predictions with high overlap, keeping only
# the highest confidence detection.
class NMSPostprocess(PostprocessPredictions):
    """
    Non-Maximum Suppression postprocessing.

    Removes overlapping predictions, keeping only the highest confidence
    detection of each overlapping group.
    """
def nms(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Non-Maximum Suppression algorithm implementation.

    Parameters:
    - predictions (np.ndarray): Predictions array with bbox and scores; the usage
      examples lay out one row per detection as [x1, y1, x2, y2, score, class_id]
    - match_threshold (float): IoU threshold for suppression
    - class_agnostic (bool): Whether to apply NMS across all classes

    Returns:
        List[int]: Indices of predictions to keep
    """
def batched_nms(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Batched Non-Maximum Suppression for efficient processing.

    Parameters:
    - predictions (np.ndarray): Predictions with bbox, scores, and classes
    - match_threshold (float): IoU threshold for suppression
    - class_agnostic (bool): Apply NMS across all classes

    Returns:
        List[int]: Indices of kept predictions
    """


# Advanced algorithm that merges overlapping predictions instead of simply
# removing them, preserving information from multiple detections.
class NMMPostprocess(PostprocessPredictions):
    """
    Non-Maximum Merging postprocessing.

    Merges overlapping predictions instead of removing them, combining confidence
    scores and bounding box coordinates to create more accurate final predictions.
    """
def nmm(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Non-Maximum Merging algorithm implementation.

    Parameters:
    - predictions (np.ndarray): Predictions array with bbox and scores
    - match_threshold (float): Overlap threshold for merging
    - class_agnostic (bool): Whether to merge across all classes

    Returns:
        List[int]: Indices of final merged predictions
    """
def batched_nmm(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Batched Non-Maximum Merging for efficient processing.

    Parameters:
    - predictions (np.ndarray): Predictions with bbox, scores, and classes
    - match_threshold (float): Overlap threshold for merging
    - class_agnostic (bool): Merge across all classes

    Returns:
        List[int]: Indices of merged predictions
    """


# Greedy variant of NMM that processes predictions in confidence order for
# improved performance on sliced inference results.
class GreedyNMMPostprocess(PostprocessPredictions):
    """
    Greedy Non-Maximum Merging postprocessing.

    Processes predictions in descending confidence order, greedily merging
    overlapping detections. Optimized for sliced inference scenarios.
    """
def greedy_nmm(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Greedy Non-Maximum Merging algorithm implementation.

    Parameters:
    - predictions (np.ndarray): Predictions array with bbox and scores; the usage
      examples lay out one row per detection as [x1, y1, x2, y2, score, class_id]
    - match_threshold (float): Overlap threshold for merging
    - class_agnostic (bool): Whether to merge across classes

    Returns:
        List[int]: Indices of merged predictions
    """
def batched_greedy_nmm(
    predictions: np.ndarray,
    match_threshold: float = 0.5,
    class_agnostic: bool = False,
) -> List[int]:
    """
    Batched Greedy Non-Maximum Merging for efficient processing.

    Parameters:
    - predictions (np.ndarray): Predictions with bbox, scores, and classes
    - match_threshold (float): Overlap threshold for merging
    - class_agnostic (bool): Merge predictions across all classes

    Returns:
        List[int]: Indices of kept predictions after merging
    """


# Soft NMS variant that gradually reduces confidence scores of overlapping
# predictions instead of hard removal.
class LSNMSPostprocess(PostprocessPredictions):
    """
    Linear Soft NMS postprocessing.

    Applies soft suppression by linearly reducing confidence scores of overlapping
    predictions instead of hard removal, preserving more detections.
    """


# String identifier -> postprocessing class. The keys match the values used for
# `postprocess_type` in the usage examples (e.g. "GREEDYNMM"); presumably this is
# the lookup behind that argument -- confirm against get_sliced_prediction.
POSTPROCESS_NAME_TO_CLASS = {
    "GREEDYNMM": GreedyNMMPostprocess,
    "NMM": NMMPostprocess,
    "NMS": NMSPostprocess,
    "LSNMS": LSNMSPostprocess,
}

# --- Usage example: NMS postprocessing with sliced prediction ---
from sahi.postprocess.combine import NMSPostprocess
# Same import path as the string-identifier example below.
from sahi.predict import get_sliced_prediction

# `model` is assumed to be a detection model loaded beforehand (not shown).

# Create NMS postprocessor
nms_postprocess = NMSPostprocess(
    match_threshold=0.5,
    match_metric="IOU",
    class_agnostic=False,
)

# Apply to sliced prediction
result = get_sliced_prediction(
    image="large_image.jpg",
    detection_model=model,
    slice_height=640,
    slice_width=640,
    postprocess=nms_postprocess,
)
print(f"Found {len(result.object_prediction_list)} objects after NMS")

# --- Usage example: Greedy NMM (recommended for sliced inference) ---
from sahi.postprocess.combine import GreedyNMMPostprocess
from sahi.predict import get_sliced_prediction

# Recommended for sliced inference.
# Named `greedy_nmm_postprocess` (not `greedy_nmm`) so it does not shadow the
# module-level `greedy_nmm` function from sahi.postprocess.combine.
greedy_nmm_postprocess = GreedyNMMPostprocess(
    match_threshold=0.5,
    match_metric="IOS",  # Intersection over Smaller area
    class_agnostic=False,
)

result = get_sliced_prediction(
    image="satellite_image.tif",
    detection_model=model,
    slice_height=1024,
    slice_width=1024,
    postprocess=greedy_nmm_postprocess,
)

# --- Usage example: comparing postprocessing algorithms ---
from sahi.postprocess.combine import (
    NMSPostprocess,
    GreedyNMMPostprocess,
    NMMPostprocess,
)
from sahi.predict import get_sliced_prediction

# Run the same image through each postprocessing algorithm and compare how many
# detections survive.
postprocessors = {
    "NMS": NMSPostprocess(match_threshold=0.5),
    "NMM": NMMPostprocess(match_threshold=0.5),
    "GreedyNMM": GreedyNMMPostprocess(match_threshold=0.5),
}

results = {}
for name, postprocessor in postprocessors.items():
    result = get_sliced_prediction(
        image="test_image.jpg",
        detection_model=model,
        postprocess=postprocessor,
    )
    results[name] = len(result.object_prediction_list)
    print(f"{name}: {results[name]} detections")

# --- Usage example: custom Greedy NMM configuration ---
from sahi.postprocess.combine import GreedyNMMPostprocess
from sahi.predict import get_sliced_prediction

# Fine-tuned for a specific use case (here: a crowded scene)
custom_postprocess = GreedyNMMPostprocess(
    match_threshold=0.3,  # Lower threshold: more pairs count as overlapping, so merging is more aggressive
    match_metric="IOS",  # Use Intersection over Smaller area
    class_agnostic=True,  # Merge across different classes
)

# Apply with sliced prediction
result = get_sliced_prediction(
    image="crowded_scene.jpg",
    detection_model=model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.3,  # Larger overlap between neighbouring slices
    overlap_width_ratio=0.3,
    postprocess=custom_postprocess,
    verbose=2,
)

# --- Usage example: selecting postprocessing by string identifier ---
from sahi.predict import get_sliced_prediction
# Use string identifiers for postprocessing instead of a postprocessor object
result = get_sliced_prediction(
    image="image.jpg",
    detection_model=model,
    postprocess_type="GREEDYNMM",  # Algorithm name (presumably a POSTPROCESS_NAME_TO_CLASS key)
    postprocess_match_metric="IOS",  # Overlap metric
    postprocess_match_threshold=0.5,  # Threshold
    postprocess_class_agnostic=False,  # Class-aware processing
)

# --- Usage example: calling the low-level functions directly ---
from sahi.postprocess.combine import greedy_nmm, nms
import numpy as np

# One detection per row: [x1, y1, x2, y2, score, class_id]
predictions = np.array(
    [
        [10, 10, 50, 50, 0.9, 0],  # person, high confidence
        [15, 15, 55, 55, 0.7, 0],  # person, overlapping the row above
        [100, 100, 150, 150, 0.8, 1],  # car, no overlap with the others
    ]
)

# Both algorithms below receive exactly the same arguments.
shared_args = {
    "predictions": predictions,
    "match_threshold": 0.5,
    "class_agnostic": False,
}

# Greedy NMM: indices of the rows that survive merging
greedy_keep = greedy_nmm(**shared_args)
merged_predictions = predictions[greedy_keep]
print(f"Kept {len(merged_predictions)} predictions after merging")

# Standard NMS on the same input, for comparison
nms_keep = nms(**shared_args)
print(f"NMS would keep {len(nms_keep)} predictions")
# Class-aware: merge only predictions of same class
from sahi.postprocess.combine import GreedyNMMPostprocess
from sahi.predict import get_sliced_prediction

# Names end in `_postprocess` so that `class_agnostic` (the keyword argument)
# is not also used as a variable name.
class_aware_postprocess = GreedyNMMPostprocess(
    match_threshold=0.5,
    class_agnostic=False,  # Only merge same-class predictions
)

# Class-agnostic: merge any overlapping predictions
class_agnostic_postprocess = GreedyNMMPostprocess(
    match_threshold=0.5,
    class_agnostic=True,  # Merge overlapping predictions regardless of class
)

# Compare results on the same image
aware_result = get_sliced_prediction(
    image="multi_class_scene.jpg",
    detection_model=model,
    postprocess=class_aware_postprocess,
)
agnostic_result = get_sliced_prediction(
    image="multi_class_scene.jpg",
    detection_model=model,
    postprocess=class_agnostic_postprocess,
)
print(f"Class-aware: {len(aware_result.object_prediction_list)} detections")
print(f"Class-agnostic: {len(agnostic_result.object_prediction_list)} detections")

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-sahi