# Session Management

Model session creation and management system that provides access to 23 AI models, each suited to particular background-removal tasks. A session encapsulates model loading, execution-provider (GPU/CPU) configuration, and prediction logic.

## Capabilities

### Session Factory

Create new model sessions with automatic execution-provider detection and configuration.

```python { .api }
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
    """
    Create a new session object based on the specified model name.

    Parameters:
    - model_name: Name of the AI model to use (default: "u2net")
    - providers: List of ONNX Runtime execution providers, passed as a
      keyword argument (auto-detected if not specified)
    - *args: Additional positional arguments passed to the session constructor
    - **kwargs: Additional keyword arguments passed to the session constructor

    Returns:
    BaseSession instance for the specified model

    Raises:
    ValueError: If model_name does not match any registered session
    """
```

**Usage Examples:**

```python
from PIL import Image
from rembg import new_session, remove

# Create the default U2Net session
session = new_session()

# Create a session for a specific model
portrait_session = new_session('birefnet-portrait')

# Create a session with explicit execution providers
gpu_session = new_session('u2net', providers=['CUDAExecutionProvider'])

# Use a session for background removal
image = Image.open('input.png')  # any PIL image
result = remove(image, session=session)
result.save('output.png')
```
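
Creating a session loads the model weights, so reusing one session per model across many images is usually much cheaper than creating a new one per image. The sketch below shows one way to do that with `functools.lru_cache`. `load_session` is a hypothetical stand-in for `rembg.new_session`, used so the example runs without downloading a model.

```python
from functools import lru_cache

# Counts how many times the (expensive) loader actually runs.
load_calls = {"count": 0}


def load_session(model_name: str) -> dict:
    """Hypothetical stand-in for rembg.new_session(model_name)."""
    load_calls["count"] += 1
    return {"model_name": model_name}


@lru_cache(maxsize=None)
def get_session(model_name: str) -> dict:
    """Return a cached session, loading each model at most once per process."""
    return load_session(model_name)


# Process a batch while sharing a single session per model.
for _ in range(3):
    session = get_session("u2net")
get_session("birefnet-portrait")

print(load_calls["count"])  # 2: one load per distinct model name
```

Only the model name is part of the cache key here. If you also forward `providers` or other keyword arguments, they must be hashable (for example, a tuple instead of a list).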

### Base Session Class

Base class that all model sessions inherit from. It provides shared preprocessing and model-storage helpers, and it declares the methods each model subclass must implement.

```python { .api }
from typing import Dict, List, Tuple

import numpy as np
import onnxruntime as ort
from PIL.Image import Image as PILImage


class BaseSession:
    """Base class for managing a session with a machine learning model."""

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions,
        *args,
        **kwargs
    ):
        """
        Initialize a session instance.

        Parameters:
        - model_name: Name of the model
        - sess_opts: ONNX Runtime session options
        - providers: List of execution providers, passed as a keyword
          argument (optional; auto-detected if omitted)
        - *args: Additional positional arguments
        - **kwargs: Additional keyword arguments
        """

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
        *args,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Normalize an input image for model inference.

        Parameters:
        - img: Input PIL image
        - mean: Per-channel RGB mean values for normalization
        - std: Per-channel RGB standard deviation values for normalization
        - size: Target size (width, height) for resizing

        Returns:
        Dictionary mapping the model's input name to the normalized image array
        """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Run the model on an image; subclasses must implement this.

        Parameters:
        - img: Input PIL image
        - *args: Additional positional arguments
        - **kwargs: Additional keyword arguments

        Returns:
        List of PIL images containing prediction masks
        """

    @classmethod
    def checksum_disabled(cls, *args, **kwargs) -> bool:
        """Return True if checksum validation is disabled via MODEL_CHECKSUM_DISABLED."""

    @classmethod
    def u2net_home(cls, *args, **kwargs) -> str:
        """Return the model storage directory (U2NET_HOME, default ~/.u2net)."""

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Download the model weights; subclasses must implement this."""

    @classmethod
    def name(cls, *args, **kwargs) -> str:
        """Return the model name used as the registry key; subclasses must implement this."""
```
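
`normalize` turns an image into model input: resize, scale pixel values, standardize each channel with `mean` and `std`, and reorder into a batched channel-first array. The NumPy-only sketch below illustrates those steps on an already-resized RGB array. It is not rembg's code: dividing by 255, the example `mean`/`std` values (the common ImageNet statistics), and the input name `"input"` are assumptions for illustration.

```python
from typing import Dict, Tuple

import numpy as np


def normalize_array(
    rgb: np.ndarray,
    mean: Tuple[float, float, float],
    std: Tuple[float, float, float],
    input_name: str = "input",  # hypothetical; real ONNX models expose their own input name
) -> Dict[str, np.ndarray]:
    """Standardize an HxWx3 uint8 image into a (1, 3, H, W) float32 batch."""
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError("expected an HxWx3 RGB array")
    scaled = rgb.astype(np.float32) / 255.0  # pixel values -> [0, 1] (assumed scaling)
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    standardized = (scaled - mean_arr) / std_arr  # broadcasts over the channel axis
    chw = standardized.transpose(2, 0, 1)  # HWC -> CHW
    return {input_name: np.expand_dims(chw, 0)}  # add batch axis -> NCHW


# Example: a 4x6 image filled with a single orange colour.
image = np.full((4, 6, 3), (255, 128, 0), dtype=np.uint8)
batch = normalize_array(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
print(batch["input"].shape)  # (1, 3, 4, 6)
```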

### Available Model Sessions

Model session classes, grouped by use case. Each `BaseSession` subclass wraps one model. The `sessions_names` registry lists every model name available in the installed version.

```python { .api }
# General-purpose models
class U2netSession(BaseSession):
    """U²-Net general-purpose background removal."""

class U2netpSession(BaseSession):
    """Lightweight version of the U²-Net model."""

class U2netCustomSession(BaseSession):
    """U²-Net with custom-trained weights."""

# Human segmentation models
class U2netHumanSegSession(BaseSession):
    """U²-Net optimized for human segmentation."""

class Unet2ClothSession(BaseSession):
    """U²-Net specialized for clothing segmentation."""

# BiRefNet models (high-quality)
class BiRefNetSessionGeneral(BaseSession):
    """BiRefNet general-purpose high-quality model."""

class BiRefNetSessionGeneralLite(BaseSession):
    """BiRefNet general-purpose lightweight model."""

class BiRefNetSessionPortrait(BaseSession):
    """BiRefNet optimized for portrait photography."""

class BiRefNetSessionDIS(BaseSession):
    """BiRefNet for Dichotomous Image Segmentation (DIS)."""

class BiRefNetSessionHRSOD(BaseSession):
    """BiRefNet for High-Resolution Salient Object Detection."""

class BiRefNetSessionCOD(BaseSession):
    """BiRefNet for Camouflaged Object Detection."""

class BiRefNetSessionMassive(BaseSession):
    """BiRefNet trained on a massive dataset for the highest quality."""

# Specialized models
class DisSession(BaseSession):
    """DIS model optimized for anime/cartoon characters."""

class DisCustomSession(BaseSession):
    """DIS model with custom-trained weights."""

class DisSessionGeneralUse(BaseSession):
    """DIS model for general use cases."""

class SamSession(BaseSession):
    """Segment Anything Model for versatile segmentation."""

class SiluetaSession(BaseSession):
    """Silueta: U²-Net with the model size reduced to about 43 MB."""

class BriaRmBgSession(BaseSession):
    """BRIA RMBG background-removal model."""

class BenCustomSession(BaseSession):
    """BEN model with custom-trained weights."""
```

### Session Registry

Module-level registry of all session classes and their model names.

```python { .api }
# Dictionary mapping model names to session classes
sessions: Dict[str, type[BaseSession]]

# List of all available model names
sessions_names: List[str]

# List of all session classes
sessions_class: List[type[BaseSession]]
```

**Usage Examples:**

```python
import onnxruntime as ort
from rembg.sessions import sessions, sessions_names, sessions_class

# List all available models
print("Available models:", sessions_names)

# Look up a session class by model name
u2net_class = sessions['u2net']

# Create a session instance directly (new_session normally builds sess_opts for you)
sess_opts = ort.SessionOptions()
session = u2net_class('u2net', sess_opts)
```
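
The registry is a plain name-to-class mapping, and `new_session` resolves `model_name` against it, raising `ValueError` for unknown names. The sketch below reproduces that lookup pattern with hypothetical stand-in classes (`Session`, `FakeU2net`, `FakeSilueta`) and a hypothetical `make_session` function. It assumes, as the `name()` classmethod suggests, that each class reports its own registry key.

```python
from typing import Dict, List, Type


class Session:
    """Minimal stand-in for BaseSession: each subclass reports its model name."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError


class FakeU2net(Session):
    @classmethod
    def name(cls) -> str:
        return "u2net"


class FakeSilueta(Session):
    @classmethod
    def name(cls) -> str:
        return "silueta"


# Build registries shaped like `sessions`, `sessions_names`, and `sessions_class`.
sessions: Dict[str, Type[Session]] = {cls.name(): cls for cls in (FakeU2net, FakeSilueta)}
sessions_names: List[str] = list(sessions.keys())
sessions_class: List[Type[Session]] = list(sessions.values())


def make_session(model_name: str = "u2net") -> Session:
    """Resolve a model name to its class and instantiate it."""
    session_class = sessions.get(model_name)
    if session_class is None:
        raise ValueError(
            f"Unknown model {model_name!r}; available: {', '.join(sessions_names)}"
        )
    return session_class(model_name)


print(type(make_session("silueta")).__name__)  # FakeSilueta
```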

## Model Selection Guide

### General Purpose
- **u2net**: Best balance of speed and quality for most images
- **u2netp**: Lightweight version of u2net; smaller and faster
- **birefnet-general**: Higher quality, slower processing
- **birefnet-general-lite**: Good quality, faster than full BiRefNet

### Portraits and People
- **birefnet-portrait**: Highest quality for portraits
- **u2net_human_seg**: Full human body segmentation

### Specialized Use Cases
- **isnet-anime**: Anime and cartoon characters
- **u2net_cloth_seg**: Clothing and fashion photography
- **sam**: Segment Anything Model for versatile segmentation of complex scenes
- **silueta**: Same as u2net, with the model size reduced to about 43 MB

### High Quality
- **birefnet-massive**: Highest quality, slowest processing
- **birefnet-hrsod**: High-resolution salient object detection
- **bria-rmbg**: BRIA RMBG background-removal model
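
The guide can be encoded as a small lookup table, for example to choose a model from a configuration value. The use-case keys, the `MODEL_CHOICES` table, and the `pick_model` helper are hypothetical. The model names follow rembg's registry spelling (hyphenated for the BiRefNet, IS-Net, and BRIA models), and candidates are checked against a list of available names such as `sessions_names`.

```python
from typing import Dict, List

# Hypothetical use-case keys mapped to a subset of the guide's models, in guide order.
MODEL_CHOICES: Dict[str, List[str]] = {
    "general": ["u2net", "birefnet-general-lite", "birefnet-general"],
    "portrait": ["birefnet-portrait", "u2net_human_seg"],
    "anime": ["isnet-anime"],
    "clothing": ["u2net_cloth_seg"],
    "high_quality": ["birefnet-massive", "birefnet-hrsod", "bria-rmbg"],
}


def pick_model(use_case: str, available: List[str]) -> str:
    """Return the first listed model for use_case that is actually available."""
    for candidate in MODEL_CHOICES.get(use_case, []):
        if candidate in available:
            return candidate
    # Fall back to the new_session default when nothing matches.
    return "u2net"


print(pick_model("portrait", ["u2net", "birefnet-portrait"]))  # birefnet-portrait
```

With rembg installed, `pick_model("anime", sessions_names)` (using `sessions_names` from `rembg.sessions`) would check the choice against the models your version actually ships.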

## GPU Configuration

Sessions automatically detect the available ONNX Runtime execution providers and use GPU acceleration when possible:

```python
from rembg import new_session

# Execution providers are auto-detected based on availability, in this order:
# - CUDAExecutionProvider (NVIDIA GPUs)
# - ROCMExecutionProvider (AMD GPUs)
# - CPUExecutionProvider (fallback)

# Manual provider specification (listed in order of preference)
session = new_session('u2net', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
```
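
The auto-detection order can be expressed as a small pure function. This is a sketch of the priority logic only, not rembg's implementation. `choose_providers` is a hypothetical helper that takes the kind of list `onnxruntime.get_available_providers()` returns and always keeps the CPU provider as the final fallback.

```python
from typing import List, Sequence

# GPU providers in order of preference: NVIDIA first, then AMD.
PREFERRED_GPU_PROVIDERS = (
    "CUDAExecutionProvider",
    "ROCMExecutionProvider",
)


def choose_providers(available: Sequence[str]) -> List[str]:
    """Pick the best available GPU provider, always ending with the CPU fallback."""
    for provider in PREFERRED_GPU_PROVIDERS:
        if provider in available:
            return [provider, "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]


print(choose_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
# ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

With ONNX Runtime installed, `choose_providers(onnxruntime.get_available_providers())` produces a list that could be passed as `providers=` to `new_session`.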

## Environment Variables

- `OMP_NUM_THREADS`: Number of threads used for CPU inference
- `MODEL_CHECKSUM_DISABLED`: If set, disables checksum validation of downloaded model files
- `U2NET_HOME`: Custom directory for model storage (default: `~/.u2net`)
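
These variables can be read with the standard library alone. The sketch below shows one plausible interpretation: any value of `MODEL_CHECKSUM_DISABLED` counts as "disabled", and `OMP_NUM_THREADS` is parsed as an integer. The helper names `model_home`, `checksum_disabled`, and `thread_count` are hypothetical. Each takes an explicit mapping so it can be tested without modifying `os.environ`.

```python
import os
from typing import Mapping, Optional


def model_home(env: Mapping[str, str] = os.environ) -> str:
    """Directory for downloaded models: U2NET_HOME if set, else ~/.u2net."""
    return os.path.expanduser(env.get("U2NET_HOME", os.path.join("~", ".u2net")))


def checksum_disabled(env: Mapping[str, str] = os.environ) -> bool:
    """True when MODEL_CHECKSUM_DISABLED is present, regardless of its value."""
    return env.get("MODEL_CHECKSUM_DISABLED") is not None


def thread_count(env: Mapping[str, str] = os.environ) -> Optional[int]:
    """Parsed OMP_NUM_THREADS, or None to keep ONNX Runtime's default."""
    raw = env.get("OMP_NUM_THREADS")
    return int(raw) if raw else None


print(thread_count({"OMP_NUM_THREADS": "4"}))  # 4
```

A non-`None` result from `thread_count()` could be assigned to the `intra_op_num_threads` attribute of an `onnxruntime.SessionOptions` object before creating a session.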