# Model Integration

SAHI provides a unified interface for loading and running detection models from several deep learning frameworks. The `AutoDetectionModel` factory class selects the framework-specific implementation and exposes a consistent API.

## Capabilities

### AutoDetectionModel Factory

`AutoDetectionModel` is the main entry point for loading detection models from different frameworks. It selects the appropriate model wrapper based on the `model_type` parameter.

```python { .api }
class AutoDetectionModel:
    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
        **kwargs,
    ) -> DetectionModel:
        """
        Load a DetectionModel from given path and model type.

        Parameters:
        - model_type (str): Framework name ("ultralytics", "mmdet", "detectron2", "huggingface", "torchvision", "yolov5", "roboflow", "rtdetr")
        - model_path (str, optional): Path to model weights file
        - model (Any, optional): Pre-initialized model instance
        - config_path (str, optional): Path to model config file (for MMDetection)
        - device (str, optional): Device specification ("cpu", "cuda", "cuda:0", etc.)
        - mask_threshold (float): Threshold for mask predictions (0-1)
        - confidence_threshold (float): Minimum confidence for detections (0-1)
        - category_mapping (Dict, optional): Map category IDs to names
        - category_remapping (Dict, optional): Remap category names to new IDs
        - load_at_init (bool): Whether to load model weights at initialization
        - image_size (int, optional): Input image size for inference

        Returns:
            DetectionModel: Framework-specific model wrapper
        """
```
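
The effect of `load_at_init` can be pictured with a small stand-in class. The sketch below is not SAHI's implementation: `LazyModelWrapper`, its `loader` argument, and `fake_loader` are hypothetical names, used only to illustrate the general pattern of deferring an expensive load until `load_model()` is called explicitly.

```python
from typing import Any, Callable, List, Optional


class LazyModelWrapper:
    """Hypothetical stand-in that mimics the `load_at_init` pattern."""

    def __init__(self, loader: Callable[[], Any], load_at_init: bool = True):
        self._loader = loader  # callable that builds the (expensive) model
        self.model: Optional[Any] = None
        if load_at_init:
            self.load_model()

    def load_model(self) -> None:
        # Build the model only once; later calls do nothing.
        if self.model is None:
            self.model = self._loader()


load_calls: List[str] = []


def fake_loader() -> str:
    # Stand-in for reading weights from disk.
    load_calls.append("loaded")
    return "weights"


eager = LazyModelWrapper(fake_loader)                     # loads immediately
lazy = LazyModelWrapper(fake_loader, load_at_init=False)  # defers loading
lazy_loaded_before = lazy.model is not None               # nothing loaded yet
lazy.load_model()                                         # explicit load
```

Deferring the load in this way can be useful when a wrapper is constructed in one process and used in another, or when the weights should only be read once a device has been chosen.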

### Supported Model Types

SAHI supports the following detection frameworks:

```python { .api }
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "ultralytics": "UltralyticsDetectionModel",
    "rtdetr": "RTDetrDetectionModel",
    "mmdet": "MmdetDetectionModel",
    "yolov5": "Yolov5DetectionModel",
    "detectron2": "Detectron2DetectionModel",
    "huggingface": "HuggingfaceDetectionModel",
    "torchvision": "TorchVisionDetectionModel",
    "roboflow": "RoboflowDetectionModel",
}

ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]
```
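
One way to picture how a factory might use these two tables is sketched below. It assumes that every name in `ULTRALYTICS_MODEL_NAMES` is treated as an alias for `"ultralytics"`; the helper `resolve_model_class_name` is hypothetical and is not part of SAHI's documented API.

```python
from typing import Dict, List

MODEL_TYPE_TO_MODEL_CLASS_NAME: Dict[str, str] = {
    "ultralytics": "UltralyticsDetectionModel",
    "rtdetr": "RTDetrDetectionModel",
    "mmdet": "MmdetDetectionModel",
    "yolov5": "Yolov5DetectionModel",
    "detectron2": "Detectron2DetectionModel",
    "huggingface": "HuggingfaceDetectionModel",
    "torchvision": "TorchVisionDetectionModel",
    "roboflow": "RoboflowDetectionModel",
}

ULTRALYTICS_MODEL_NAMES: List[str] = ["yolov8", "yolov11", "yolo11", "ultralytics"]


def resolve_model_class_name(model_type: str) -> str:
    """Hypothetical helper: map a model_type string to a wrapper class name."""
    # Assumption: Ultralytics aliases all resolve to the same wrapper.
    if model_type in ULTRALYTICS_MODEL_NAMES:
        model_type = "ultralytics"
    if model_type not in MODEL_TYPE_TO_MODEL_CLASS_NAME:
        supported = ", ".join(sorted(MODEL_TYPE_TO_MODEL_CLASS_NAME))
        raise ValueError(
            f"Unsupported model_type {model_type!r}. Supported: {supported}"
        )
    return MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]


alias_class = resolve_model_class_name("yolov8")
mmdet_class = resolve_model_class_name("mmdet")
```

Failing early with the list of supported names tends to make a mistyped `model_type` easier to diagnose than a bare `KeyError`.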

### Base DetectionModel Interface

All model integrations inherit from the base `DetectionModel` class, providing consistent APIs across frameworks.

```python { .api }
class DetectionModel:
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
    ): ...

    def load_model(self): ...
    def set_model(self, model: Any): ...
    def set_device(self, device: str): ...
    def perform_inference(self, image: np.ndarray) -> List: ...
    def convert_original_predictions(
        self,
        shift_amount: Optional[List[int]] = [0, 0],
        full_shape: Optional[List[int]] = None,
    ) -> ObjectPrediction: ...
```
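
The `shift_amount` and `full_shape` arguments of `convert_original_predictions` relate predictions made on an image slice to the full image. The sketch below illustrates the idea only. It assumes boxes are `[x_min, y_min, x_max, y_max]`, that `shift_amount` is `[shift_x, shift_y]`, and that `full_shape` is `[height, width]`; the helper `shift_box` and its clipping behavior are hypothetical, not SAHI's implementation.

```python
from typing import List, Optional


def shift_box(
    box: List[float],
    shift_amount: Optional[List[int]] = None,
    full_shape: Optional[List[int]] = None,
) -> List[float]:
    """Hypothetical helper: move a slice-level box into full-image coordinates."""
    shift_x, shift_y = shift_amount or [0, 0]
    x_min, y_min, x_max, y_max = box
    shifted = [x_min + shift_x, y_min + shift_y, x_max + shift_x, y_max + shift_y]
    if full_shape is not None:
        height, width = full_shape
        # Assumption: clip so the box never extends past the full image borders.
        shifted = [
            min(max(shifted[0], 0), width),
            min(max(shifted[1], 0), height),
            min(max(shifted[2], 0), width),
            min(max(shifted[3], 0), height),
        ]
    return shifted


# A box found in a slice whose top-left corner sits at (512, 256) in the full image.
full_image_box = shift_box(
    [10, 20, 110, 220], shift_amount=[512, 256], full_shape=[1080, 1920]
)
```

With the default shift of `[0, 0]` the box is returned unchanged, which corresponds to running inference on the whole image rather than on a slice.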

### Framework-Specific Models

#### Ultralytics (YOLO) Integration

```python { .api }
class UltralyticsDetectionModel(DetectionModel):
    """
    Ultralytics YOLO model wrapper for YOLOv8, YOLOv11, and other Ultralytics models.
    Supports both detection and segmentation models.
    """
```

#### MMDetection Integration

```python { .api }
class MmdetDetectionModel(DetectionModel):
    """
    MMDetection framework integration supporting a wide range of detection
    and segmentation models including Faster R-CNN, Mask R-CNN, RetinaNet, etc.
    """
```

#### Detectron2 Integration

```python { .api }
class Detectron2DetectionModel(DetectionModel):
    """
    Facebook Detectron2 framework integration for state-of-the-art
    object detection and instance segmentation models.
    """
```

#### HuggingFace Transformers Integration

```python { .api }
class HuggingfaceDetectionModel(DetectionModel):
    """
    HuggingFace Transformers integration for transformer-based detection models
    like DETR, RT-DETR, and other vision transformer architectures.
    """
```

#### TorchVision Integration

```python { .api }
class TorchVisionDetectionModel(DetectionModel):
    """
    PyTorch TorchVision integration for official PyTorch detection models
    including Faster R-CNN, Mask R-CNN, RetinaNet, and SSD.
    """
```

#### YOLOv5 Integration

```python { .api }
class Yolov5DetectionModel(DetectionModel):
    """
    YOLOv5 model integration for Ultralytics YOLOv5 models with
    custom loading and inference pipeline.
    """
```

#### Roboflow Integration

```python { .api }
class RoboflowDetectionModel(DetectionModel):
    """
    Roboflow platform integration for deploying and using models
    trained on the Roboflow platform.
    """
```

#### RT-DETR Integration

```python { .api }
class RTDetrDetectionModel(DetectionModel):
    """
    RT-DETR (Real-Time Detection Transformer) model integration
    for fast transformer-based object detection.
    """
```

## Usage Examples

### Loading Different Model Types

```python
from sahi import AutoDetectionModel

# Ultralytics YOLO model
yolo_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='yolov8n.pt',
    confidence_threshold=0.25,
    device='cuda:0'
)

# MMDetection model
mmdet_model = AutoDetectionModel.from_pretrained(
    model_type='mmdet',
    model_path='checkpoint.pth',
    config_path='configs/faster_rcnn_r50_fpn_1x_coco.py',
    confidence_threshold=0.3,
    device='cuda:0'
)

# HuggingFace model
hf_model = AutoDetectionModel.from_pretrained(
    model_type='huggingface',
    model_path='facebook/detr-resnet-50',
    confidence_threshold=0.5,
    device='cpu'
)

# Detectron2 model
d2_model = AutoDetectionModel.from_pretrained(
    model_type='detectron2',
    model_path='detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl',
    confidence_threshold=0.5,
    device='cuda:0'
)
```

### Advanced Configuration

```python
from sahi import AutoDetectionModel

# Custom category mapping
category_mapping = {
    0: "person",
    1: "bicycle",
    2: "car",
    3: "motorcycle"
}

# Category remapping for custom datasets
category_remapping = {
    "person": 1,
    "vehicle": 2
}

model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='custom_model.pt',
    confidence_threshold=0.25,
    mask_threshold=0.5,
    category_mapping=category_mapping,
    category_remapping=category_remapping,
    image_size=640,
    device='cuda:0'
)
```
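
To make the relationship between `category_mapping` and `category_remapping` more concrete, the sketch below applies both to a couple of made-up predictions. It is an illustration under assumptions, not SAHI's internal logic: predictions are represented as plain dictionaries, and `remap_categories` is a hypothetical helper that rewrites a prediction's ID whenever its category name appears in the remapping.

```python
from typing import Dict, List


def remap_categories(
    predictions: List[Dict],
    category_remapping: Dict[str, int],
) -> List[Dict]:
    """Hypothetical helper: rewrite category IDs by category name."""
    remapped = []
    for prediction in predictions:
        updated = dict(prediction)  # copy so the input is left untouched
        name = updated["category_name"]
        if name in category_remapping:
            updated["category_id"] = category_remapping[name]
        remapped.append(updated)
    return remapped


category_mapping = {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle"}
category_remapping = {"person": 1, "vehicle": 2}

# Made-up raw predictions keyed by the model's own category IDs.
raw = [{"category_id": 0, "score": 0.9}, {"category_id": 2, "score": 0.8}]

# Step 1: attach names using category_mapping (ID -> name).
named = [{**p, "category_name": category_mapping[p["category_id"]]} for p in raw]

# Step 2: rewrite IDs using category_remapping (name -> new ID).
result = remap_categories(named, category_remapping)
```

In this sketch `"person"` moves from ID 0 to ID 1, while `"car"` keeps its ID because it has no entry in the remapping. The `"vehicle"` entry has no effect here, since no name in `category_mapping` matches it.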

### Using Pre-loaded Models

```python
from sahi import AutoDetectionModel
from ultralytics import YOLO

# Load model externally
external_model = YOLO('yolov8n.pt')

# Pass to SAHI
sahi_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model=external_model,
    confidence_threshold=0.25
)
```