
# Model Integration

SAHI provides a unified interface for loading and running detection models from several deep learning frameworks. The `AutoDetectionModel` factory class handles the framework-specific implementation details while exposing the same API for every backend.

## Capabilities

### AutoDetectionModel Factory

`AutoDetectionModel` is the main entry point for loading detection models from different frameworks. It selects the appropriate model wrapper class based on the `model_type` parameter.

```python { .api }
class AutoDetectionModel:
    @staticmethod
    def from_pretrained(
        model_type: str,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
        **kwargs,
    ) -> DetectionModel:
        """
        Load a DetectionModel from the given path and model type.

        Parameters:
        - model_type (str): Framework name ("ultralytics", "mmdet", "detectron2", "huggingface", "torchvision", "yolov5", "roboflow", "rtdetr")
        - model_path (str, optional): Path to the model weights file
        - model (Any, optional): Pre-initialized model instance
        - config_path (str, optional): Path to the model config file (e.g. for MMDetection or Detectron2)
        - device (str, optional): Device specification ("cpu", "cuda", "cuda:0", etc.)
        - mask_threshold (float): Threshold for mask predictions (0-1)
        - confidence_threshold (float): Minimum confidence for detections (0-1)
        - category_mapping (Dict, optional): Map category IDs, given as strings (e.g. {"0": "person"}), to names
        - category_remapping (Dict, optional): Remap category names to new IDs
        - load_at_init (bool): Whether to load model weights at initialization
        - image_size (int, optional): Input image size for inference

        Returns:
            DetectionModel: Framework-specific model wrapper
        """
```

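The two threshold parameters act at different levels: `confidence_threshold` drops whole detections, while `mask_threshold` turns a soft per-pixel mask into a binary one. The pure-Python sketch below illustrates that reading of the parameter list. The helpers `filter_by_confidence` and `binarize_mask` are hypothetical and not part of SAHI, and whether each cut-off is inclusive (`>=`) or strict (`>`) is this sketch's assumption and may differ between backends.

```python
from typing import Dict, List

# Hypothetical helpers illustrating the two thresholds; not SAHI's internals.


def filter_by_confidence(detections: List[Dict], confidence_threshold: float = 0.3) -> List[Dict]:
    """Keep detections whose score reaches the threshold (inclusive cut-off assumed)."""
    return [d for d in detections if d["score"] >= confidence_threshold]


def binarize_mask(soft_mask: List[List[float]], mask_threshold: float = 0.5) -> List[List[bool]]:
    """Turn per-pixel probabilities into a boolean mask (strict '>' assumed)."""
    return [[p > mask_threshold for p in row] for row in soft_mask]


detections = [
    {"category": "person", "score": 0.91},
    {"category": "car", "score": 0.28},
    {"category": "bicycle", "score": 0.30},
]
kept = filter_by_confidence(detections, confidence_threshold=0.3)
print([d["category"] for d in kept])  # ['person', 'bicycle']

mask = binarize_mask([[0.2, 0.7], [0.5, 0.9]], mask_threshold=0.5)
print(mask)  # [[False, True], [False, True]]
```

Raising `confidence_threshold` trades recall for precision across the whole image, while `mask_threshold` only changes the shape of masks that have already survived the confidence filter.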
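### Supported Model Types

The `model_type` argument accepts the following keys, each mapped to the name of its wrapper class; `ULTRALYTICS_MODEL_NAMES` lists the names grouped under the Ultralytics integration: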

```python { .api }
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "ultralytics": "UltralyticsDetectionModel",
    "rtdetr": "RTDetrDetectionModel",
    "mmdet": "MmdetDetectionModel",
    "yolov5": "Yolov5DetectionModel",
    "detectron2": "Detectron2DetectionModel",
    "huggingface": "HuggingfaceDetectionModel",
    "torchvision": "TorchVisionDetectionModel",
    "roboflow": "RoboflowDetectionModel",
}

ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]
```

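A name-based table like this lends itself to a simple factory: look up the class name for `model_type`, resolve it to a class, and instantiate it with the remaining arguments. The sketch below shows that pattern with a trimmed two-entry mapping and hypothetical stand-in classes; its in-memory `CLASS_REGISTRY` replaces however SAHI actually resolves class names, and treating every entry of `ULTRALYTICS_MODEL_NAMES` as an alias for `"ultralytics"` is an assumption of this sketch.

```python
from typing import Dict, Type


# Hypothetical stand-ins for the real wrapper classes.
class DetectionModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class UltralyticsDetectionModel(DetectionModel):
    pass


class MmdetDetectionModel(DetectionModel):
    pass


MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "ultralytics": "UltralyticsDetectionModel",
    "mmdet": "MmdetDetectionModel",
}
ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]

# In-memory registry from class name to class (an assumption of this sketch).
CLASS_REGISTRY: Dict[str, Type[DetectionModel]] = {
    cls.__name__: cls for cls in (UltralyticsDetectionModel, MmdetDetectionModel)
}


def from_pretrained(model_type: str, **kwargs) -> DetectionModel:
    # Assumption: Ultralytics-family names all resolve to the same wrapper.
    if model_type in ULTRALYTICS_MODEL_NAMES:
        model_type = "ultralytics"
    try:
        class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
    except KeyError:
        supported = ", ".join(sorted(MODEL_TYPE_TO_MODEL_CLASS_NAME))
        raise ValueError(f"Unsupported model_type {model_type!r}; expected one of: {supported}") from None
    return CLASS_REGISTRY[class_name](**kwargs)


model = from_pretrained("yolo11", model_path="yolo11n.pt", confidence_threshold=0.25)
print(type(model).__name__)  # UltralyticsDetectionModel
```

Keeping the mapping as plain strings means a framework's module only needs to be imported when that `model_type` is requested, which avoids pulling in every optional dependency up front.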
### Base DetectionModel Interface

All model integrations inherit from the base `DetectionModel` class, so loading, device selection, inference, and prediction conversion use the same methods across frameworks.

```python { .api }
class DetectionModel:
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: Optional[int] = None,
    ): ...

    def load_model(self): ...
    def set_model(self, model: Any): ...
    def set_device(self, device: str): ...
    def perform_inference(self, image: np.ndarray) -> List: ...
    def convert_original_predictions(
        self,
        shift_amount: Optional[List[int]] = [0, 0],
        full_shape: Optional[List[int]] = None,
    ) -> ObjectPrediction: ...
```

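The interface splits work into two steps: `perform_inference` runs the framework model, and `convert_original_predictions` turns the raw output into framework-independent predictions, using `shift_amount` to move boxes from a slice back into full-image coordinates. The toy sketch below mirrors that two-step contract in plain Python. `ToyDetectionModel`, `SimplePrediction`, and `fake_model` are hypothetical, and the conventions `shift_amount = [shift_x, shift_y]` and `full_shape = [height, width]` are assumptions of the sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

Box = Tuple[int, int, int, int]


@dataclass
class SimplePrediction:
    """Hypothetical stand-in for SAHI's ObjectPrediction."""

    bbox: Box  # (x1, y1, x2, y2) in full-image pixels
    score: float
    category_id: int


def _clip(value: int, upper: int) -> int:
    return min(max(value, 0), upper)


class ToyDetectionModel:
    """Sketch of the contract: perform_inference, then convert_original_predictions."""

    def __init__(self, model: Callable[[Any], list], confidence_threshold: float = 0.3):
        self.model = model
        self.confidence_threshold = confidence_threshold
        self._original_predictions: list = []

    def perform_inference(self, image: Any) -> None:
        # Store the raw, framework-specific output for later conversion.
        self._original_predictions = self.model(image)

    def convert_original_predictions(
        self,
        shift_amount: Optional[List[int]] = None,  # [shift_x, shift_y] (assumed)
        full_shape: Optional[List[int]] = None,  # [height, width] (assumed)
    ) -> List[SimplePrediction]:
        shift_x, shift_y = shift_amount or [0, 0]
        results = []
        for x1, y1, x2, y2, score, category_id in self._original_predictions:
            if score < self.confidence_threshold:
                continue
            # Move slice-local coordinates into full-image coordinates.
            x1, x2 = x1 + shift_x, x2 + shift_x
            y1, y2 = y1 + shift_y, y2 + shift_y
            if full_shape is not None:
                height, width = full_shape
                x1, x2 = _clip(x1, width), _clip(x2, width)
                y1, y2 = _clip(y1, height), _clip(y2, height)
            results.append(SimplePrediction((x1, y1, x2, y2), score, category_id))
        return results


def fake_model(image: Any) -> list:
    # Ignores the image; returns (x1, y1, x2, y2, score, category_id) rows.
    return [(10, 20, 50, 60, 0.9, 0), (0, 0, 5, 5, 0.1, 1), (90, 90, 120, 130, 0.8, 2)]


toy = ToyDetectionModel(model=fake_model)
toy.perform_inference(image=None)
preds = toy.convert_original_predictions(shift_amount=[100, 200], full_shape=[300, 210])
print([p.bbox for p in preds])  # [(110, 220, 150, 260), (190, 290, 210, 300)]
```

Separating the two steps keeps each framework wrapper responsible only for parsing its own raw output, while slicing logic can stay generic.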
### Framework-Specific Models

#### Ultralytics (YOLO) Integration

```python { .api }
class UltralyticsDetectionModel(DetectionModel):
    """
    Ultralytics YOLO model wrapper for YOLOv8, YOLOv11, and other Ultralytics models.
    Supports both detection and segmentation models.
    """
```

#### MMDetection Integration

```python { .api }
class MmdetDetectionModel(DetectionModel):
    """
    MMDetection framework integration supporting a wide range of detection
    and segmentation models, including Faster R-CNN, Mask R-CNN, and RetinaNet.
    """
```

#### Detectron2 Integration

```python { .api }
class Detectron2DetectionModel(DetectionModel):
    """
    Facebook Detectron2 framework integration for state-of-the-art
    object detection and instance segmentation models.
    """
```

#### HuggingFace Transformers Integration

```python { .api }
class HuggingfaceDetectionModel(DetectionModel):
    """
    HuggingFace Transformers integration for transformer-based detection models
    like DETR, RT-DETR, and other vision transformer architectures.
    """
```

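DETR-style models in Hugging Face Transformers predict boxes as normalized `(center_x, center_y, width, height)` values in [0, 1], so a wrapper has to convert them to absolute corner coordinates before the boxes can be shifted or merged across slices. A minimal sketch of that conversion follows; `cxcywh_norm_to_xyxy` is a hypothetical name, and a real integration may rely on Transformers' own post-processing utilities instead.

```python
from typing import List, Sequence


def cxcywh_norm_to_xyxy(box: Sequence[float], image_width: int, image_height: int) -> List[float]:
    """Convert a normalized (cx, cy, w, h) box to absolute (x1, y1, x2, y2) pixels."""
    cx, cy, w, h = box
    x1 = (cx - w / 2) * image_width
    y1 = (cy - h / 2) * image_height
    x2 = (cx + w / 2) * image_width
    y2 = (cy + h / 2) * image_height
    return [x1, y1, x2, y2]


# A box centered in a 640x480 image, covering half of each dimension.
print(cxcywh_norm_to_xyxy([0.5, 0.5, 0.5, 0.5], image_width=640, image_height=480))
# [160.0, 120.0, 480.0, 360.0]
```

The image size used here must be that of the image (or slice) the model actually saw, otherwise the absolute coordinates come out scaled incorrectly.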
#### TorchVision Integration

```python { .api }
class TorchVisionDetectionModel(DetectionModel):
    """
    PyTorch TorchVision integration for official PyTorch detection models
    including Faster R-CNN, Mask R-CNN, RetinaNet, and SSD.
    """
```

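TorchVision detection models return, for each input image, a dict with `boxes` (absolute x1, y1, x2, y2), `labels`, and `scores` tensors, plus `masks` for Mask R-CNN. The sketch below mimics that structure with plain lists so it runs without PyTorch, and shows the kind of confidence filtering and label-to-name lookup a wrapper performs. `parse_torchvision_output` and the two-entry `label_names` table are hypothetical; with real tensors you would call `.tolist()` on each field first.

```python
from typing import Dict, List


def parse_torchvision_output(
    output: Dict[str, list],
    label_names: Dict[int, str],
    confidence_threshold: float = 0.3,
) -> List[Dict]:
    """Turn one image's output dict into filtered, named detections."""
    detections = []
    for box, label, score in zip(output["boxes"], output["labels"], output["scores"]):
        if score < confidence_threshold:
            continue
        detections.append({
            "bbox": list(box),
            "category_id": label,
            # Fall back to the numeric label when no name is known.
            "category_name": label_names.get(label, str(label)),
            "score": score,
        })
    return detections


# Plain-list stand-in for a TorchVision output dict (normally tensors).
output = {
    "boxes": [[12.0, 30.0, 88.0, 140.0], [200.0, 50.0, 260.0, 90.0]],
    "labels": [1, 3],
    "scores": [0.97, 0.12],
}
label_names = {1: "person", 3: "car"}  # hypothetical subset of class names
print(parse_torchvision_output(output, label_names))
```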
#### YOLOv5 Integration

```python { .api }
class Yolov5DetectionModel(DetectionModel):
    """
    YOLOv5 model integration for Ultralytics YOLOv5 models with
    a custom loading and inference pipeline.
    """
```

#### Roboflow Integration

```python { .api }
class RoboflowDetectionModel(DetectionModel):
    """
    Roboflow platform integration for deploying and using models
    trained on the Roboflow platform.
    """
```

#### RT-DETR Integration

```python { .api }
class RTDetrDetectionModel(DetectionModel):
    """
    RT-DETR (Real-Time Detection Transformer) model integration
    for fast transformer-based object detection.
    """
```

## Usage Examples

### Loading Different Model Types

```python
from sahi import AutoDetectionModel

# Ultralytics YOLO model
yolo_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='yolov8n.pt',
    confidence_threshold=0.25,
    device='cuda:0',
)

# MMDetection model: needs both a checkpoint and its config file
mmdet_model = AutoDetectionModel.from_pretrained(
    model_type='mmdet',
    model_path='checkpoint.pth',
    config_path='configs/faster_rcnn_r50_fpn_1x_coco.py',
    confidence_threshold=0.3,
    device='cuda:0',
)

# HuggingFace model: model_path is a Hugging Face Hub model ID
hf_model = AutoDetectionModel.from_pretrained(
    model_type='huggingface',
    model_path='facebook/detr-resnet-50',
    confidence_threshold=0.5,
    device='cpu',
)

# Detectron2 model: pass the matching model-zoo config as config_path
d2_model = AutoDetectionModel.from_pretrained(
    model_type='detectron2',
    model_path='detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl',
    config_path='COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml',
    confidence_threshold=0.5,
    device='cuda:0',
)
```

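The examples differ mainly in which arguments each backend needs. A small helper can collect those differences in one place and validate them before calling `AutoDetectionModel.from_pretrained(**kwargs)`. `build_model_kwargs` is hypothetical, and its rules (a `config_path` for `mmdet` and `detectron2`, and either `model_path` or `model` for every backend) are this sketch's assumptions, not checks SAHI is known to perform.

```python
from typing import Any, Dict, Optional

# Assumed per-backend requirements; not enforced by SAHI itself.
REQUIRES_CONFIG = {"mmdet", "detectron2"}
SUPPORTED = {"ultralytics", "rtdetr", "mmdet", "yolov5", "detectron2",
             "huggingface", "torchvision", "roboflow"}


def build_model_kwargs(
    model_type: str,
    model_path: Optional[str] = None,
    model: Optional[Any] = None,
    config_path: Optional[str] = None,
    confidence_threshold: float = 0.3,
    device: Optional[str] = None,
) -> Dict[str, Any]:
    """Validate arguments and return kwargs for AutoDetectionModel.from_pretrained."""
    if model_type not in SUPPORTED:
        raise ValueError(f"unknown model_type: {model_type!r}")
    if model_path is None and model is None:
        raise ValueError("pass either model_path or a pre-loaded model")
    if model_type in REQUIRES_CONFIG and config_path is None:
        raise ValueError(f"{model_type} needs config_path")
    if not 0.0 <= confidence_threshold <= 1.0:
        raise ValueError("confidence_threshold must be within [0, 1]")
    kwargs: Dict[str, Any] = {"model_type": model_type,
                              "confidence_threshold": confidence_threshold}
    # Only forward the optional arguments that were actually given.
    for key, value in (("model_path", model_path), ("model", model),
                       ("config_path", config_path), ("device", device)):
        if value is not None:
            kwargs[key] = value
    return kwargs


kwargs = build_model_kwargs("mmdet", model_path="checkpoint.pth",
                            config_path="configs/faster_rcnn_r50_fpn_1x_coco.py",
                            device="cuda:0")
# With SAHI installed: AutoDetectionModel.from_pretrained(**kwargs)
print(sorted(kwargs))
```

Failing early on a missing config file gives a clearer error than a failure deep inside a framework's loading code.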
### Advanced Configuration

```python
from sahi import AutoDetectionModel

# Custom category mapping: keys are category IDs given as strings
category_mapping = {
    "0": "person",
    "1": "bicycle",
    "2": "car",
    "3": "motorcycle",
}

# Category remapping for custom datasets
category_remapping = {
    "person": 1,
    "vehicle": 2,
}

model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model_path='custom_model.pt',
    confidence_threshold=0.25,
    mask_threshold=0.5,  # only relevant for segmentation models
    category_mapping=category_mapping,
    category_remapping=category_remapping,
    image_size=640,
    device='cuda:0',
)
```

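Following the parameter descriptions, the two dictionaries apply in sequence: `category_mapping` turns a predicted ID into a name, then `category_remapping` assigns a new ID by name. The sketch below shows that sequence on raw IDs. `apply_category_maps` is hypothetical, and keeping the original ID when a name has no remapping entry is this sketch's own choice.

```python
from typing import Dict, List, Tuple


def apply_category_maps(
    predicted_ids: List[int],
    category_mapping: Dict[str, str],
    category_remapping: Dict[str, int],
) -> List[Tuple[int, str]]:
    """Return (final_id, name) pairs for each raw predicted category ID."""
    results = []
    for cat_id in predicted_ids:
        # Mapping keys are string IDs, so convert before the lookup.
        name = category_mapping[str(cat_id)]
        # Fall back to the original ID when the name is not remapped.
        results.append((category_remapping.get(name, cat_id), name))
    return results


category_mapping = {"0": "person", "1": "bicycle", "2": "car", "3": "motorcycle"}
category_remapping = {"person": 1, "car": 2}
print(apply_category_maps([0, 2, 3], category_mapping, category_remapping))
# [(1, 'person'), (2, 'car'), (3, 'motorcycle')]
```

A remapping entry whose name never appears in the mapping has no effect, so it is worth checking that both dictionaries use the same category names.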
### Using Pre-loaded Models

```python
from sahi import AutoDetectionModel
from ultralytics import YOLO

# Load the model externally with Ultralytics
external_model = YOLO('yolov8n.pt')

# Pass the pre-loaded instance to SAHI instead of a weights path
sahi_model = AutoDetectionModel.from_pretrained(
    model_type='ultralytics',
    model=external_model,
    confidence_threshold=0.25
)
```
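Per the parameter list, `model=` supplies an already-loaded instance, while `load_at_init` controls whether weights are loaded from `model_path` when the wrapper is created. The sketch below shows that lifecycle with a hypothetical `LazyModelWrapper` whose `loader` callable stands in for a framework's own loading code. It illustrates the pattern only; loading on the first prediction when initialization was deferred is this sketch's choice, not documented SAHI behavior.

```python
from typing import Any, Callable, Optional


class LazyModelWrapper:
    """Hypothetical wrapper: use a given model, or load from a path now or later."""

    def __init__(
        self,
        loader: Callable[[str], Any],
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        load_at_init: bool = True,
    ):
        self.loader = loader
        self.model_path = model_path
        self.model = None
        if model is not None:
            self.set_model(model)  # pre-loaded instance: nothing to load
        elif load_at_init:
            self.load_model()

    def load_model(self) -> None:
        if self.model_path is None:
            raise ValueError("model_path is required when no model is given")
        self.set_model(self.loader(self.model_path))

    def set_model(self, model: Any) -> None:
        self.model = model

    def predict(self, image: Any) -> Any:
        if self.model is None:  # deferred: load on first use
            self.load_model()
        return self.model(image)


load_calls = []


def fake_loader(path: str):
    # Records each load and returns a stand-in "model" callable.
    load_calls.append(path)
    return lambda image: f"detections for {image} from {path}"


deferred = LazyModelWrapper(fake_loader, model_path="yolov8n.pt", load_at_init=False)
print(load_calls)  # []
print(deferred.predict("img.jpg"))  # detections for img.jpg from yolov8n.pt
preloaded = LazyModelWrapper(fake_loader, model=lambda image: "external")
print(load_calls)  # ['yolov8n.pt']
```

Reusing a pre-loaded instance is useful when the same model also serves non-sliced inference elsewhere, since the weights are loaded only once.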