0
# Utilities
1
2
SAHI provides utility functions for computer vision operations, file I/O, PyTorch helpers, dependency checks, and integrations with different deep learning frameworks. These utilities support SAHI's core functionality and are also available as standalone convenience functions.
3
4
## Capabilities
5
6
### Computer Vision Utilities
7
8
Core computer vision operations including image reading, visualization of predictions, and cropping of detected objects.
9
10
```python { .api }
11
def read_image_as_pil(image_path: str) -> Image.Image:
12
"""
13
Read image as PIL Image object.
14
15
Parameters:
16
- image_path (str): Path to image file
17
18
Returns:
19
Image.Image: PIL Image object
20
"""
21
22
def visualize_object_predictions(
23
image: np.ndarray,
24
object_prediction_list: List[ObjectPrediction],
25
rect_th: int = 3,
26
text_size: float = 3,
27
text_th: float = 3,
28
color: tuple = None,
29
hide_labels: bool = False,
30
hide_conf: bool = False,
31
output_dir: Optional[str] = None,
32
file_name: Optional[str] = "prediction_visual",
33
) -> np.ndarray:
34
"""
35
Visualize object predictions on image with bounding boxes and labels.
36
37
Parameters:
38
- image (np.ndarray): Input image array
39
- object_prediction_list (List[ObjectPrediction]): List of ObjectPrediction instances
40
- rect_th (int): Rectangle thickness for bounding boxes
41
- text_size (float): Text size for labels
42
- text_th (float): Text thickness
43
- color (tuple, optional): Custom color for all boxes (BGR format)
44
- hide_labels (bool): Hide class labels
45
- hide_conf (bool): Hide confidence scores
46
- output_dir (str, optional): Directory to save visualization
47
- file_name (str, optional): Name for saved visualization file (default: "prediction_visual")
48
49
Returns:
50
np.ndarray: Visualized image with annotations
51
"""
52
53
def crop_object_predictions(
54
image: np.ndarray,
55
object_prediction_list: List[ObjectPrediction],
56
output_dir: str,
57
file_name: str = "prediction_visual",
58
export_format: str = "deepcrop",
59
) -> Dict:
60
"""
61
Crop detected objects from image and save individually.
62
63
Parameters:
64
- image (np.ndarray): Source image array
65
- object_prediction_list (List[ObjectPrediction]): List of ObjectPrediction instances
66
- output_dir (str): Directory for saving cropped images
67
- file_name (str): Base name for cropped files
68
- export_format (str): Export format ("deepcrop", "crop")
69
70
Returns:
71
Dict: Dictionary with crop information and file paths
72
"""
73
```
74
75
### Image Format and Conversion Utilities
76
77
```python { .api }
78
def get_coco_segmentation_from_bool_mask(bool_mask: np.ndarray) -> List[List[float]]:
79
"""
80
Convert boolean mask to COCO polygon segmentation format.
81
82
Parameters:
83
- bool_mask (np.ndarray): Boolean mask array
84
85
Returns:
86
List[List[float]]: COCO format polygon coordinates
87
"""
88
89
def get_bool_mask_from_coco_segmentation(
90
segmentation: List,
91
height: int,
92
width: int
93
) -> np.ndarray:
94
"""
95
Convert COCO segmentation to boolean mask.
96
97
Parameters:
98
- segmentation (List): COCO format polygon segmentation
99
- height (int): Mask height
100
- width (int): Mask width
101
102
Returns:
103
np.ndarray: Boolean mask array
104
"""
105
106
def get_bbox_from_coco_segmentation(segmentation: List) -> List[int]:
107
"""
108
Extract bounding box from COCO segmentation.
109
110
Parameters:
111
- segmentation (List): COCO format polygon segmentation
112
113
Returns:
114
List[int]: Bounding box [xmin, ymin, width, height]
115
"""
116
```
117
118
### Color Management
119
120
```python { .api }
121
class Colors:
122
"""
123
Color palette for consistent visualization across different plots and frameworks.
124
Provides color management for bounding boxes, labels, and visualization elements.
125
"""
126
127
def __init__(self):
128
"""Initialize color palette with predefined colors."""
129
130
def __call__(self, i: int, bgr: bool = False) -> Tuple[int, ...]:
131
"""
132
Get color for given index.
133
134
Parameters:
135
- i (int): Color index
136
- bgr (bool): Return BGR format instead of RGB
137
138
Returns:
139
Tuple[int, ...]: Color tuple (RGB or BGR)
140
"""
141
```
142
143
### File I/O Utilities
144
145
Comprehensive file operations supporting multiple formats and efficient data handling.
146
147
```python { .api }
148
def save_json(data: Dict, save_path: str):
149
"""
150
Save data as JSON file with proper formatting.
151
152
Parameters:
153
- data (Dict): Data to save
154
- save_path (str): Output file path
155
"""
156
157
def load_json(load_path: str) -> Dict:
158
"""
159
Load JSON file as dictionary.
160
161
Parameters:
162
- load_path (str): Path to JSON file
163
164
Returns:
165
Dict: Loaded data
166
"""
167
168
def save_pickle(data: Any, save_path: str):
169
"""
170
Save data as pickle file for efficient storage.
171
172
Parameters:
173
- data (Any): Data to save
174
- save_path (str): Output file path
175
"""
176
177
def load_pickle(load_path: str) -> Any:
178
"""
179
Load pickle file. Only load files from trusted sources: unpickling can execute arbitrary code.
180
181
Parameters:
182
- load_path (str): Path to pickle file
183
184
Returns:
185
Any: Loaded data
186
"""
187
188
def list_files(
189
directory: str,
190
contains: Optional[List[str]] = None,
191
extensions: Optional[List[str]] = None,
192
recursive: bool = True,
193
) -> List[str]:
194
"""
195
List files in directory with filtering options.
196
197
Parameters:
198
- directory (str): Directory to search
199
- contains (List[str], optional): Substrings that filenames must contain
200
- extensions (List[str], optional): File extensions to include
201
- recursive (bool): Search subdirectories recursively
202
203
Returns:
204
List[str]: List of matching file paths
205
"""
206
207
def download_from_url(url: str, save_path: str):
208
"""
209
Download file from URL.
210
211
Parameters:
212
- url (str): URL to download from
213
- save_path (str): Local path to save file
214
"""
215
216
def import_model_class(model_class_name: str, model_type: str):
217
"""
218
Dynamically import model class based on type.
219
220
Parameters:
221
- model_class_name (str): Name of model class to import
222
- model_type (str): Model framework type
223
224
Returns:
225
Type: Imported model class
226
"""
227
```
228
229
### PyTorch Utilities
230
231
Utilities for PyTorch tensor operations and device management.
232
233
```python { .api }
234
def empty_cuda_cache():
235
"""Clear CUDA memory cache to free up GPU memory."""
236
237
def to_float_tensor(image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
238
"""
239
Convert image to PyTorch float tensor.
240
241
Parameters:
242
- image (Union[np.ndarray, Image.Image]): Input image (numpy array or PIL Image)
243
244
Returns:
245
torch.Tensor: Float tensor in CHW format
246
"""
247
248
def torch_to_numpy(tensor: torch.Tensor) -> np.ndarray:
249
"""
250
Convert PyTorch tensor to numpy array.
251
252
Parameters:
253
- tensor (torch.Tensor): Input tensor
254
255
Returns:
256
np.ndarray: Numpy array
257
"""
258
259
def select_device(device: Optional[str] = None) -> torch.device:
260
"""
261
Select appropriate PyTorch device for inference.
262
263
Parameters:
264
- device (str, optional): Device specification ("cpu", "cuda", "mps", etc.)
265
266
Returns:
267
torch.device: Selected PyTorch device
268
"""
269
```
270
271
### Import and Environment Utilities
272
273
Utilities for checking dependencies and managing package imports.
274
275
```python { .api }
276
def is_available(package: str) -> bool:
277
"""
278
Check if package is available for import.
279
280
Parameters:
281
- package (str): Package name to check
282
283
Returns:
284
bool: True if package is available
285
"""
286
287
def check_requirements(
288
requirements: List[str],
289
raise_exception: bool = True
290
):
291
"""
292
Verify that required packages are installed.
293
294
Parameters:
295
- requirements (List[str]): List of required package names
296
- raise_exception (bool): Whether to raise exception if packages missing
297
298
Raises:
299
ImportError: If required packages are missing and raise_exception=True
300
"""
301
302
def get_package_info(package_name: str) -> Dict[str, str]:
303
"""
304
Get information about installed package.
305
306
Parameters:
307
- package_name (str): Name of package to query
308
309
Returns:
310
Dict[str, str]: Package information (version, location, etc.)
311
"""
312
313
def print_environment_info():
314
"""
315
Print comprehensive environment and dependency information.
316
Includes Python version, PyTorch version, CUDA availability,
317
system information, and installed package versions.
318
"""
319
```
320
321
### Framework-Specific Utilities
322
323
Utilities for specific deep learning framework integrations.
324
325
```python { .api }
326
# Detectron2 utilities
327
def convert_detectron2_bbox_format(bbox: List) -> List:
328
"""Convert Detectron2 bbox format to standard format."""
329
330
def convert_detectron2_mask_format(mask: np.ndarray) -> np.ndarray:
331
"""Convert Detectron2 mask format to standard format."""
332
333
# MMDetection utilities
334
def convert_mmdet_bbox_format(bbox: List) -> List:
335
"""Convert MMDetection bbox format to standard format."""
336
337
def convert_mmdet_mask_format(mask: np.ndarray) -> np.ndarray:
338
"""Convert MMDetection mask format to standard format."""
339
340
# TorchVision utilities
341
def convert_torchvision_bbox_format(bbox: torch.Tensor) -> List:
342
"""Convert TorchVision bbox format to standard format."""
343
344
# RT-DETR utilities
345
def convert_rtdetr_output_format(outputs: Dict) -> List:
346
"""Convert RT-DETR output format to standard ObjectPrediction format."""
347
```
348
349
### File Path and Video Utilities
350
351
```python { .api }
352
class Path:
353
"""Enhanced path handling with additional convenience methods."""
354
355
def __init__(self, path: str):
356
"""Initialize path handler."""
357
358
@property
359
def suffix(self) -> str:
360
"""Get file extension."""
361
362
@property
363
def stem(self) -> str:
364
"""Get filename without extension."""
365
366
def increment_path(path: str, exist_ok: bool = False) -> str:
367
"""
368
Increment file path to avoid overwrites.
369
370
Parameters:
371
- path (str): Original path
372
- exist_ok (bool): Whether existing path is acceptable
373
374
Returns:
375
str: Incremented path (e.g., "file_1.txt", "file_2.txt")
376
"""
377
378
def get_video_reader(video_path: str):
379
"""
380
Get video reader object for frame-by-frame processing.
381
382
Parameters:
383
- video_path (str): Path to video file
384
385
Returns:
386
Video reader object
387
"""
388
```
389
390
### Constants
391
392
```python { .api }
393
# Supported file extensions
394
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.tiff', '.bmp', '.gif']
395
VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.wmv']
396
IMAGE_EXTENSIONS_LOSSLESS = ['.png', '.tiff', '.bmp']
397
IMAGE_EXTENSIONS_LOSSY = ['.jpg', '.jpeg']
398
```
399
400
## Usage Examples
401
402
### Image Processing and Visualization
403
404
```python
405
import numpy as np
from sahi.utils.cv import read_image_as_pil, visualize_object_predictions
406
from sahi import get_sliced_prediction
407
408
# Read image
409
image = read_image_as_pil("input_image.jpg")
410
411
# Get predictions
412
result = get_sliced_prediction(
413
image="input_image.jpg",
414
detection_model=model
415
)
416
417
# Visualize predictions
418
visualized = visualize_object_predictions(
419
image=np.array(image),
420
object_prediction_list=result.object_prediction_list,
421
rect_th=3,
422
text_size=1.0,
423
hide_conf=False,
424
output_dir="visualizations/",
425
file_name="result"
426
)
427
```
428
429
### File Operations
430
431
```python
432
from sahi.utils.file import save_json, load_json, list_files
433
434
# Save prediction results
435
predictions_data = {
436
"predictions": [pred.json for pred in result.object_prediction_list],
437
"metadata": {"model": "yolov8n", "confidence": 0.25}
438
}
439
save_json(predictions_data, "predictions.json")
440
441
# Load data
442
loaded_data = load_json("predictions.json")
443
444
# List image files
445
image_files = list_files(
446
directory="dataset/",
447
extensions=[".jpg", ".png"],
448
contains=["train", "val"],
449
recursive=True
450
)
451
print(f"Found {len(image_files)} image files")
452
```
453
454
### Environment and Dependency Management
455
456
```python
457
from sahi.utils.import_utils import is_available, check_requirements, print_environment_info
458
459
# Check if optional dependencies are available
460
if is_available("fiftyone"):
461
print("FiftyOne integration available")
462
463
if is_available("mmdet"):
464
print("MMDetection integration available")
465
466
# Verify required packages
467
try:
468
check_requirements(["torch", "torchvision", "ultralytics"])
469
print("All requirements satisfied")
470
except ImportError as e:
471
print(f"Missing requirements: {e}")
472
473
# Print full environment info
474
print_environment_info()
475
```
476
477
### PyTorch Utilities
478
479
```python
480
from sahi.utils.torch_utils import select_device, empty_cuda_cache, to_float_tensor
481
import numpy as np
482
483
# Explicitly request the CUDA device
484
device = select_device("cuda")
485
print(f"Using device: {device}")
486
487
# Convert image to tensor
488
image_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
489
tensor = to_float_tensor(image_array)
490
print(f"Tensor shape: {tensor.shape}")
491
492
# Clear CUDA cache after processing
493
empty_cuda_cache()
494
```
495
496
### Custom Colors for Visualization
497
498
```python
499
from sahi.utils.cv import Colors, visualize_object_predictions
500
501
# Initialize color palette
502
colors = Colors()
503
504
# Get specific colors
505
first_color = colors(0)  # First color in palette (RGB tuple by default)
506
second_color = colors(1)  # Second color
507
third_color = colors(2)  # Third color
508
509
# Use custom color for visualization
510
visualized = visualize_object_predictions(
511
image=image_array,
512
object_prediction_list=predictions,
513
color=(0, 255, 0), # Custom green color
514
rect_th=2,
515
text_size=0.8
516
)
517
```
518
519
### File Path Management
520
521
```python
522
from sahi.utils.file import increment_path, Path
523
524
# Avoid overwriting existing files
525
output_path = increment_path("results/experiment.json")
526
print(f"Using path: {output_path}") # e.g., "results/experiment_1.json"
527
528
# Enhanced path handling
529
path = Path("dataset/images/sample.jpg")
530
print(f"Extension: {path.suffix}") # ".jpg"
531
print(f"Filename: {path.stem}") # "sample"
532
```
533
534
### Format Conversions
535
536
```python
537
import numpy as np
from sahi.utils.cv import (
538
get_coco_segmentation_from_bool_mask,
539
get_bool_mask_from_coco_segmentation,
540
get_bbox_from_coco_segmentation
541
)
542
543
# Create boolean mask
544
bool_mask = np.random.rand(100, 100) > 0.5
545
546
# Convert to COCO format
547
coco_segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
548
549
# Convert back to boolean mask
550
reconstructed_mask = get_bool_mask_from_coco_segmentation(
551
coco_segmentation, 100, 100
552
)
553
554
# Extract bounding box from segmentation
555
bbox = get_bbox_from_coco_segmentation(coco_segmentation)
556
print(f"Bounding box: {bbox}")
557
```