# Prediction Functions

SAHI's core prediction capabilities provide standard inference, sliced inference for large images, batch processing, and comprehensive parameter control. These functions form the foundation of SAHI's slicing-aided inference approach.

## Capabilities

### Standard Prediction

Performs object detection on a single image without slicing. Suitable for regular-sized images or when slicing is not needed.

```python { .api }
def get_prediction(
    image,
    detection_model,
    shift_amount: list = [0, 0],
    full_shape=None,
    postprocess: Optional[PostprocessPredictions] = None,
    verbose: int = 0,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
    """
    Perform detection prediction on a single image.

    Parameters:
    - image: Image path (str) or numpy array
    - detection_model: Loaded DetectionModel instance
    - shift_amount (list): Coordinate shift [shift_x, shift_y] for prediction mapping
    - full_shape: Original image shape [height, width] if using crops
    - postprocess: PostprocessPredictions instance for combining predictions
    - verbose (int): Verbosity level (0=silent, 1=print duration)
    - exclude_classes_by_name: List of class names to exclude from results
    - exclude_classes_by_id: List of class IDs to exclude from results

    Returns:
    PredictionResult: Container with predictions, image, and timing info
    """
```
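
The `shift_amount` and `full_shape` parameters matter when the input is a crop of a larger image: boxes predicted on the crop are in crop-local coordinates and have to be translated back into the full image. The sketch below illustrates that translation in plain Python. `shift_box` is a hypothetical helper written for this illustration, not a SAHI function, and the clamping to `full_shape` is an assumption about how one might keep boxes inside the image.

```python
from typing import List, Optional, Sequence


def shift_box(
    box_xyxy: Sequence[float],
    shift_amount: Sequence[int] = (0, 0),
    full_shape: Optional[Sequence[int]] = None,
) -> List[float]:
    """Translate a crop-local [x1, y1, x2, y2] box into full-image coordinates.

    Hypothetical helper for illustration only; not part of SAHI.
    """
    shift_x, shift_y = shift_amount
    x1, y1, x2, y2 = box_xyxy
    shifted = [x1 + shift_x, y1 + shift_y, x2 + shift_x, y2 + shift_y]
    if full_shape is not None:
        # Assumption: clamp to the full image, given as [height, width].
        height, width = full_shape
        shifted = [
            min(max(shifted[0], 0), width),
            min(max(shifted[1], 0), height),
            min(max(shifted[2], 0), width),
            min(max(shifted[3], 0), height),
        ]
    return shifted


# A box found in a crop whose top-left corner sits at (x=512, y=256).
print(shift_box([10, 20, 110, 220], shift_amount=[512, 256], full_shape=[1080, 1920]))
```

With the default `shift_amount` of `[0, 0]` the box is returned unchanged, which corresponds to running prediction on the full image.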

### Sliced Inference

The core SAHI functionality: it slices a large image into overlapping patches, runs inference on each patch, and merges the per-patch predictions in a postprocessing step (`GREEDYNMM` by default).

```python { .api }
def get_sliced_prediction(
    image,
    detection_model,
    slice_height: Optional[int] = None,
    slice_width: Optional[int] = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: Optional[int] = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: Optional[str] = None,
    slice_dir: Optional[str] = None,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
    """
    Perform sliced inference on large images for better small object detection.

    Parameters:
    - image: Image path (str) or numpy array
    - detection_model: Loaded DetectionModel instance
    - slice_height (int, optional): Height of each slice in pixels
    - slice_width (int, optional): Width of each slice in pixels
    - overlap_height_ratio (float): Vertical overlap ratio between slices (0-1)
    - overlap_width_ratio (float): Horizontal overlap ratio between slices (0-1)
    - perform_standard_pred (bool): Perform standard prediction on full image in addition to sliced prediction
    - postprocess_type (str): Postprocessing method ("GREEDYNMM", "NMM", "NMS", "LSNMS")
    - postprocess_match_metric (str): Overlap metric for combining predictions ("IOU", "IOS")
    - postprocess_match_threshold (float): Overlap threshold for merging (0-1)
    - postprocess_class_agnostic (bool): Whether to ignore class when merging
    - verbose (int): Verbosity level (0=silent, 1=progress, 2=detailed)
    - merge_buffer_length (int, optional): Buffer length for low memory sliced prediction
    - auto_slice_resolution (bool): Auto-calculate slice dimensions from image size
    - slice_export_prefix (str, optional): Prefix for exported slice files
    - slice_dir (str, optional): Directory to save slice images
    - exclude_classes_by_name: List of class names to exclude
    - exclude_classes_by_id: List of class IDs to exclude

    Returns:
    PredictionResult: Combined predictions from all slices
    """
```
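
To make the interaction between slice size and overlap ratio concrete, here is a minimal sketch of how overlapping slice windows can be laid out over an image. `compute_slice_boxes` is a hypothetical function written for this illustration; SAHI's own slicing code may differ in details such as rounding and edge handling.

```python
from typing import List


def compute_slice_boxes(
    image_height: int,
    image_width: int,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """Return [x_min, y_min, x_max, y_max] windows that tile the image with overlap.

    Illustrative sketch only; not SAHI's implementation.
    """
    # The stride is the slice size minus the overlapping part.
    y_step = slice_height - int(overlap_height_ratio * slice_height)
    x_step = slice_width - int(overlap_width_ratio * slice_width)
    if y_step <= 0 or x_step <= 0:
        raise ValueError("overlap ratios must be smaller than 1")

    boxes = []
    y_min = 0
    while True:
        y_max = min(y_min + slice_height, image_height)
        # Pull the last row back so it keeps the full slice height where possible.
        y_start = max(0, y_max - slice_height)
        x_min = 0
        while True:
            x_max = min(x_min + slice_width, image_width)
            # Same adjustment for the last column.
            x_start = max(0, x_max - slice_width)
            boxes.append([x_start, y_start, x_max, y_max])
            if x_max >= image_width:
                break
            x_min += x_step
        if y_max >= image_height:
            break
        y_min += y_step
    return boxes


boxes = compute_slice_boxes(1000, 1000, slice_height=512, slice_width=512)
print(len(boxes), boxes[0], boxes[-1])
```

For a 1000x1000 image with 512-pixel slices and 0.2 overlap, this layout yields a 3x3 grid of nine windows. A larger overlap ratio produces more windows and more duplicate detections for the postprocessing step to merge.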

### Comprehensive Prediction Pipeline

High-level entry point that loads a model, runs inference over an image, a directory, or a video file, and exports the results. It exposes the slicing and postprocessing options alongside output management settings.

```python { .api }
def predict(
    model_type: str = "yolov8",
    model_path: Optional[str] = None,
    model_device: Optional[str] = None,
    model_confidence_threshold: float = 0.25,
    source: Optional[str] = None,
    slice_height: Optional[int] = None,
    slice_width: Optional[int] = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    export_pickle: bool = False,
    export_crop: bool = False,
    export_visual: bool = True,
    project: str = "runs/predict",
    name: str = "exp",
    return_dict: bool = False,
    force_postprocess: bool = False,
    frame_skip_interval: int = 0,
    export_format: str = "coco",
    verbose: int = 1,
    crop_class_agnostic: bool = True,
    desired_name2id: Optional[Dict[str, int]] = None,
    auto_slice_resolution: bool = True,
) -> Optional[Dict]:
    """
    Comprehensive prediction pipeline with model loading, inference, and export.

    Parameters:
    - model_type (str): Detection framework ("ultralytics", "mmdet", etc.)
    - model_path (str): Path to model weights
    - model_device (str): Device for inference ("cpu", "cuda", etc.)
    - model_confidence_threshold (float): Minimum confidence for detections
    - source (str): Input path (image, directory, or video file)
    - slice_height (int): Slice height in pixels (None for auto)
    - slice_width (int): Slice width in pixels (None for auto)
    - overlap_height_ratio (float): Vertical overlap between slices
    - overlap_width_ratio (float): Horizontal overlap between slices
    - postprocess_type (str): Postprocessing algorithm
    - postprocess_match_metric (str): Overlap calculation method
    - postprocess_match_threshold (float): Threshold for combining predictions
    - postprocess_class_agnostic (bool): Class-agnostic postprocessing
    - export_pickle (bool): Save predictions as pickle files
    - export_crop (bool): Export cropped detected objects
    - export_visual (bool): Export visualization images
    - project (str): Base directory for outputs
    - name (str): Experiment name for output subdirectory
    - return_dict (bool): Return results as dictionary
    - force_postprocess (bool): Force postprocessing even for single predictions
    - frame_skip_interval (int): Skip frames in video processing
    - export_format (str): Output format ("coco", "yolo", "fiftyone")
    - verbose (int): Verbosity level
    - crop_class_agnostic (bool): Class-agnostic cropping
    - desired_name2id (Dict): Custom category name to ID mapping
    - auto_slice_resolution (bool): Auto-calculate slice parameters

    Returns:
    Dict or None: Prediction results if return_dict=True
    """
```

### FiftyOne Integration

Runs standard or sliced prediction over a COCO-format dataset (`dataset_json_path` plus `image_dir`) and integrates the predictions into a FiftyOne dataset.

```python { .api }
def predict_fiftyone(
    model_type: str = "mmdet",
    model_path: Optional[str] = None,
    model_config_path: Optional[str] = None,
    model_confidence_threshold: float = 0.25,
    model_device: Optional[str] = None,
    model_category_mapping: Optional[dict] = None,
    model_category_remapping: Optional[dict] = None,
    dataset_json_path: str = "",
    image_dir: str = "",
    no_standard_prediction: bool = False,
    no_sliced_prediction: bool = False,
    image_size: Optional[int] = None,
    slice_height: int = 256,
    slice_width: int = 256,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
):
    """
    Perform predictions on FiftyOne datasets with automatic result integration.

    Parameters:
    - model_type (str): Detection framework type ("mmdet", "yolov5", etc.)
    - model_path (str, optional): Path to model weights
    - model_config_path (str, optional): Path to model config file (for MMDetection)
    - model_confidence_threshold (float): Detection confidence threshold
    - model_device (str, optional): Inference device ("cpu", "cuda", etc.)
    - model_category_mapping (dict, optional): Category ID to name mapping
    - model_category_remapping (dict, optional): Category remapping after inference
    - dataset_json_path (str): Path to COCO format dataset JSON
    - image_dir (str): Directory containing dataset images
    - no_standard_prediction (bool): Skip standard (full image) prediction
    - no_sliced_prediction (bool): Skip sliced prediction
    - image_size (int, optional): Input image size for inference
    - slice_height (int): Slice height for large images
    - slice_width (int): Slice width for large images
    - overlap_height_ratio (float): Vertical slice overlap
    - overlap_width_ratio (float): Horizontal slice overlap
    - postprocess_type (str): Postprocessing method
    - postprocess_match_metric (str): Overlap metric for combining
    - postprocess_match_threshold (float): Overlap threshold
    - postprocess_class_agnostic (bool): Class-agnostic postprocessing
    - verbose (int): Verbosity level
    - exclude_classes_by_name (List[str], optional): Class names to exclude
    - exclude_classes_by_id (List[int], optional): Class IDs to exclude

    Returns:
    FiftyOne dataset with predictions integrated
    """
```

### Utility Functions

```python { .api }
def filter_predictions(
    object_prediction_list: List[ObjectPrediction],
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None
) -> List[ObjectPrediction]:
    """
    Filter predictions by excluding specified classes.

    Parameters:
    - object_prediction_list: List of ObjectPrediction instances
    - exclude_classes_by_name: Class names to exclude
    - exclude_classes_by_id: Class IDs to exclude

    Returns:
    List of filtered ObjectPrediction instances
    """
```
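
The filtering behavior described above can be sketched without SAHI installed. In the example below, `Detection` is a hypothetical stand-in for `ObjectPrediction` and `filter_detections` is an illustrative reimplementation; it assumes that a prediction is dropped when either its class name or its class ID appears on an exclusion list.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Detection:
    """Hypothetical stand-in for ObjectPrediction with only the fields used here."""

    category_id: int
    category_name: str
    score: float


def filter_detections(
    detections: List[Detection],
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> List[Detection]:
    """Drop detections whose class name or class ID is on an exclusion list."""
    excluded_names = set(exclude_classes_by_name or [])
    excluded_ids = set(exclude_classes_by_id or [])
    return [
        det
        for det in detections
        if det.category_name not in excluded_names
        and det.category_id not in excluded_ids
    ]


detections = [
    Detection(0, "person", 0.91),
    Detection(2, "car", 0.84),
    Detection(7, "truck", 0.66),
]
kept = filter_detections(detections, exclude_classes_by_name=["person"], exclude_classes_by_id=[7])
print([det.category_name for det in kept])
```

Passing `None` for both lists keeps every detection, which matches the defaults in the signature above.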

## Usage Examples

### Basic Sliced Inference

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Load model
model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='yolov8n.pt',
    confidence_threshold=0.3
)

# Perform sliced inference
result = get_sliced_prediction(
    image="large_image.jpg",
    detection_model=model,
    slice_height=640,
    slice_width=640,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)

print(f"Found {len(result.object_prediction_list)} objects")
```

### Advanced Postprocessing Configuration

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Load model
model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="yolov8n.pt",
    confidence_threshold=0.3,
)

# get_sliced_prediction has no `postprocess` argument;
# postprocessing is configured through the postprocess_* parameters
result = get_sliced_prediction(
    image="image.jpg",
    detection_model=model,
    slice_height=512,
    slice_width=512,
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
    verbose=2,
)
```
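
The two `postprocess_match_metric` options differ in their denominator: IOU divides the intersection by the union of both boxes, while IOS divides it by the area of the smaller box. The following plain-Python sketch shows the effect on a small box nested inside a large one. The helper names are hypothetical and the code is an illustration of the metrics, not SAHI's implementation.

```python
from typing import Sequence


def box_area(box: Sequence[float]) -> float:
    """Area of an [x1, y1, x2, y2] box; degenerate boxes count as zero."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def intersection_area(a: Sequence[float], b: Sequence[float]) -> float:
    """Area of the overlap between two [x1, y1, x2, y2] boxes."""
    width = min(a[2], b[2]) - max(a[0], b[0])
    height = min(a[3], b[3]) - max(a[1], b[1])
    return max(0.0, width) * max(0.0, height)


def iou(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection over union."""
    inter = intersection_area(a, b)
    union = box_area(a) + box_area(b) - inter
    return inter / union if union > 0 else 0.0


def ios(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection over the smaller of the two box areas."""
    inter = intersection_area(a, b)
    smaller = min(box_area(a), box_area(b))
    return inter / smaller if smaller > 0 else 0.0


large_box = [0, 0, 100, 100]
small_box = [10, 10, 30, 30]  # fully contained in large_box
print(f"IOU={iou(large_box, small_box):.2f}, IOS={ios(large_box, small_box):.2f}")
```

Here IOU is 0.04 and IOS is 1.0, so with a match threshold of 0.5 the pair would be merged under IOS but kept separate under IOU. This is why IOS tends to merge partial detections of an object that was cut by a slice boundary.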

### Batch Processing with Export

```python
from sahi.predict import predict

# Process entire directory with exports
predict(
    model_type="ultralytics",
    model_path="yolov8n.pt",
    source="images/",
    slice_height=640,
    slice_width=640,
    export_visual=True,
    export_crop=True,
    export_format="coco",
    project="results",
    name="experiment_1"
)
```

### Video Processing

```python
from sahi.predict import predict

# Process video with frame skipping
predict(
    model_type="ultralytics",
    model_path="yolov8n.pt",
    source="video.mp4",
    frame_skip_interval=5,  # Skip 5 frames between processed frames
    slice_height=640,
    slice_width=640,
    export_visual=True
)
```
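
### Class Filtering

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Load model
model = AutoDetectionModel.from_pretrained(model_type="ultralytics", model_path="yolov8n.pt")

# Exclude specific classes
result = get_sliced_prediction(
    image="image.jpg",
    detection_model=model,
    exclude_classes_by_name=["person", "bicycle"],
    exclude_classes_by_id=[2, 3, 5]
)
```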