
# Session Management

Model session creation and management system that provides access to 23 different AI models optimized for various background removal tasks. Sessions encapsulate model loading, GPU configuration, and prediction logic.

## Capabilities

### Session Factory

Create new model sessions with automatic execution-provider detection and configuration.

```python { .api }
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
    """
    Create a new session object based on the specified model name.

    Parameters:
    - model_name: Name of the AI model to use (default: "u2net")
    - providers: List of execution providers, passed as a keyword argument
      (auto-detected if not specified)
    - *args: Additional positional arguments passed to the session
    - **kwargs: Additional keyword arguments passed to the session

    Returns:
    BaseSession instance for the specified model

    Raises:
    ValueError: If model_name is not found in the available sessions
    """
```

**Usage Examples:**

```python
from PIL import Image
from rembg import new_session, remove

# Create default U2Net session
session = new_session()

# Create specific model session (BiRefNet model names use hyphens)
portrait_session = new_session('birefnet-portrait')

# Create session with custom providers
gpu_session = new_session('u2net', providers=['CUDAExecutionProvider'])

# Use session for background removal
image = Image.open('input.png')  # example input path
result = remove(image, session=session)
```
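The name lookup that `new_session` performs can be pictured roughly as follows. This is a simplified, self-contained sketch, not rembg's source: `StubSession`, `FakeU2net`, `FakeBiRefNetPortrait`, and `new_session_sketch` are hypothetical stand-ins, and real sessions additionally receive ONNX Runtime session options and load model weights.

```python
from typing import List, Type


class StubSession:
    """Hypothetical stand-in for BaseSession; real sessions load an ONNX model."""

    def __init__(self, model_name: str, *args, **kwargs):
        self.model_name = model_name
        # Keep extra keyword arguments (e.g. providers) to show they are forwarded
        self.kwargs = kwargs

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError


class FakeU2net(StubSession):
    @classmethod
    def name(cls) -> str:
        return "u2net"


class FakeBiRefNetPortrait(StubSession):
    @classmethod
    def name(cls) -> str:
        return "birefnet-portrait"


SESSION_CLASSES: List[Type[StubSession]] = [FakeU2net, FakeBiRefNetPortrait]


def new_session_sketch(model_name: str = "u2net", *args, **kwargs) -> StubSession:
    # Find the first registered class whose name() matches the requested model
    for session_class in SESSION_CLASSES:
        if session_class.name() == model_name:
            # Forward any extra arguments (such as providers) to the session
            return session_class(model_name, *args, **kwargs)
    raise ValueError(f"No session class found for model '{model_name}'")
```

Because matching is done on `name()`, the string must equal the registered model name exactly; anything else surfaces as the `ValueError` listed in the factory's docstring.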

### Base Session Class

Abstract base class that all model sessions inherit from, providing common functionality.

```python { .api }
class BaseSession:
    """Base class for managing a session with a machine learning model."""

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions,
        *args,
        **kwargs
    ):
        """
        Initialize a session instance.

        Parameters:
        - model_name: Name of the model
        - sess_opts: ONNX Runtime session options
        - providers: List of execution providers (optional, passed via **kwargs)
        - *args: Additional positional arguments
        - **kwargs: Additional keyword arguments
        """

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
        *args,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Normalize input image for model inference.

        Parameters:
        - img: Input PIL image
        - mean: RGB mean values for normalization
        - std: RGB standard deviation values for normalization
        - size: Target size (width, height) for resizing

        Returns:
        Dictionary mapping the model's input name to the normalized image array
        """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Abstract method for model prediction.

        Parameters:
        - img: Input PIL image
        - *args: Additional positional arguments
        - **kwargs: Additional keyword arguments

        Returns:
        List of PIL Images containing prediction masks
        """

    @classmethod
    def checksum_disabled(cls, *args, **kwargs) -> bool:
        """Check if model checksum validation is disabled via environment variable."""

    @classmethod
    def u2net_home(cls, *args, **kwargs) -> str:
        """Get the home directory for model storage."""

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Abstract method for downloading model weights."""

    @classmethod
    def name(cls, *args, **kwargs) -> str:
        """Abstract method returning the model name."""
```
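The normalization step can be sketched with NumPy alone. This sketch assumes the image has already been resized to `size` and converted to an RGB `uint8` array, so the PIL resize is omitted; `normalize_sketch` and the default input name `"input"` are hypothetical, since a real session would use its ONNX model's own input name. Scaling by the image maximum before standardizing is likewise an assumption of this example.

```python
from typing import Dict, Tuple

import numpy as np


def normalize_sketch(
    rgb: np.ndarray,
    mean: Tuple[float, float, float],
    std: Tuple[float, float, float],
    input_name: str = "input",
) -> Dict[str, np.ndarray]:
    """Normalize an (H, W, 3) RGB array into a (1, 3, H, W) float32 model input."""
    arr = rgb.astype(np.float64)
    # Scale pixel values into [0, 1] using the image maximum (guard against all-zero images)
    arr = arr / max(float(arr.max()), 1e-6)
    # Per-channel standardization: (x - mean) / std, broadcast over the last axis
    arr = (arr - np.asarray(mean)) / np.asarray(std)
    # Reorder HWC -> CHW and add a leading batch dimension
    chw = arr.transpose((2, 0, 1))
    return {input_name: np.expand_dims(chw, 0).astype(np.float32)}
```

A call might pass the widely used ImageNet statistics, e.g. `normalize_sketch(rgb, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))`; which statistics a given model expects depends on how it was trained.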

### Available Model Sessions

Available AI model session classes, each optimized for specific use cases.

```python { .api }
# General-purpose models
class U2netSession(BaseSession):
    """U²-Net general-purpose background removal."""

class U2netpSession(BaseSession):
    """Lightweight (smaller, faster) version of U²-Net."""

class U2netCustomSession(BaseSession):
    """U²-Net architecture with custom-trained weights."""

# Human segmentation models
class U2netHumanSegSession(BaseSession):
    """U²-Net optimized for human segmentation."""

class Unet2ClothSession(BaseSession):
    """U²-Net specialized for clothing segmentation."""

# BiRefNet models (high-quality)
class BiRefNetSessionGeneral(BaseSession):
    """BiRefNet general-purpose high-quality model."""

class BiRefNetSessionGeneralLite(BaseSession):
    """BiRefNet general-purpose lightweight model."""

class BiRefNetSessionPortrait(BaseSession):
    """BiRefNet optimized for portrait photography."""

class BiRefNetSessionDIS(BaseSession):
    """BiRefNet for DIS (Dichotomous Image Segmentation)."""

class BiRefNetSessionHRSOD(BaseSession):
    """BiRefNet for High-Resolution Salient Object Detection."""

class BiRefNetSessionCOD(BaseSession):
    """BiRefNet for Camouflaged Object Detection."""

class BiRefNetSessionMassive(BaseSession):
    """BiRefNet massive model for highest quality."""

# Specialized models
class DisSession(BaseSession):
    """DIS model optimized for anime/cartoon characters."""

class DisCustomSession(BaseSession):
    """DIS model with custom-trained weights."""

class DisSessionGeneralUse(BaseSession):
    """DIS model for general use cases."""

class SamSession(BaseSession):
    """Segment Anything Model for versatile segmentation."""

class SiluetaSession(BaseSession):
    """Silueta: a size-reduced variant of U²-Net."""

class BriaRmBgSession(BaseSession):
    """BRIA model specialized for background removal."""

class BenCustomSession(BaseSession):
    """BEN model with custom-trained weights."""
```

### Session Registry

Access to the complete session registry and model names.

```python { .api }
# Dictionary mapping model names to session classes
sessions: Dict[str, type[BaseSession]]

# List of all available model names
sessions_names: List[str]

# List of all session classes
sessions_class: List[type[BaseSession]]
```

**Usage Examples:**

```python
import onnxruntime as ort
from rembg.sessions import sessions, sessions_names, sessions_class

# List all available models
print("Available models:", sessions_names)

# Get session class by name
u2net_class = sessions['u2net']

# Create session instance directly (new_session() normally builds sess_opts for you)
sess_opts = ort.SessionOptions()
session = u2net_class('u2net', sess_opts)
```
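A registry of this shape can be built by keying each class on its `name()`. The sketch below is a hypothetical, self-contained illustration (`SketchSession`, `SketchU2net`, and `SketchSilueta` are stand-ins, not rembg classes); it shows how `sessions`, `sessions_names`, and `sessions_class` relate to one another.

```python
from typing import Dict, List, Type


class SketchSession:
    """Hypothetical minimal base class exposing a name() classmethod."""

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError


class SketchU2net(SketchSession):
    @classmethod
    def name(cls) -> str:
        return "u2net"


class SketchSilueta(SketchSession):
    @classmethod
    def name(cls) -> str:
        return "silueta"


# Map each model name to its session class
sessions: Dict[str, Type[SketchSession]] = {}
for session_cls in (SketchU2net, SketchSilueta):
    sessions[session_cls.name()] = session_cls

# Derived views in registration order (dicts preserve insertion order)
sessions_names: List[str] = list(sessions.keys())
sessions_class: List[Type[SketchSession]] = list(sessions.values())
```

Deriving both lists from one dictionary keeps them consistent; registering a second class under an existing name would simply overwrite the earlier entry.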

## Model Selection Guide

### General Purpose
- **u2net**: Best balance of speed and quality for most images
- **u2netp**: Lightweight version of u2net; smaller and faster at some cost in quality
- **birefnet-general**: Higher quality, slower processing
- **birefnet-general-lite**: Good quality, faster than full BiRefNet

### Portraits and People
- **birefnet-portrait**: Highest quality for portraits
- **u2net_human_seg**: Full human body segmentation

### Specialized Use Cases
- **isnet-anime**: Anime and cartoon characters
- **u2net_cloth_seg**: Clothing and fashion photography
- **sam**: Versatile segmentation for complex scenes
- **silueta**: Size-reduced variant of u2net (smaller download)

### High Quality
- **birefnet-massive**: Highest quality, slowest processing
- **birefnet-hrsod**: High-resolution salient object detection
- **bria-rmbg**: BRIA's dedicated background-removal model
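If an application chooses a model from a use-case label, the guide can be encoded as a small lookup table. The helper `pick_model` and its category keys are hypothetical; the model strings assume rembg's registered names (BiRefNet and IS-Net names use hyphens, e.g. `birefnet-portrait`), and the helper checks them against the names actually available so that a version without a given model falls back to the default.

```python
from typing import Dict, List

# Hypothetical use-case table distilled from the selection guide.
# Candidates are tried in order; the first one that is registered wins.
MODEL_CHOICES: Dict[str, List[str]] = {
    "general": ["u2net", "birefnet-general", "birefnet-general-lite"],
    "portrait": ["birefnet-portrait", "u2net_human_seg"],
    "anime": ["isnet-anime"],
    "clothing": ["u2net_cloth_seg"],
    "high_quality": ["birefnet-massive", "birefnet-hrsod", "bria-rmbg"],
}


def pick_model(use_case: str, available: List[str]) -> str:
    """Return the first candidate for use_case that is actually registered, else 'u2net'."""
    for candidate in MODEL_CHOICES.get(use_case, []):
        if candidate in available:
            return candidate
    # Fall back to the library's default model
    return "u2net"
```

In practice `available` would be `rembg.sessions.sessions_names`, and the returned string is what gets passed to `new_session`.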

## GPU Configuration

Sessions automatically detect and configure GPU acceleration through ONNX Runtime execution providers:

```python
import onnxruntime as ort
from rembg import new_session

# GPU providers are auto-detected based on availability:
# - CUDAExecutionProvider (NVIDIA GPUs)
# - ROCMExecutionProvider (AMD GPUs)
# - CPUExecutionProvider (fallback)
print(ort.get_available_providers())

# Manual provider specification (listed in priority order)
session = new_session('u2net', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
```
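The detection order above can be expressed as a pure function. This is a hedged sketch of the selection policy only: `choose_providers` is hypothetical, it takes the list that `onnxruntime.get_available_providers()` would return instead of querying ONNX Runtime itself, and dropping requested providers that are not installed is this sketch's own choice. Explicit requests win; otherwise CUDA, then ROCm, with CPU kept as the fallback.

```python
from typing import List, Optional


def choose_providers(
    available: List[str], requested: Optional[List[str]] = None
) -> List[str]:
    """Pick ONNX Runtime execution providers in priority order (hypothetical policy)."""
    if requested:
        # Honor an explicit request, dropping providers this machine does not offer
        chosen = [p for p in requested if p in available]
        if chosen:
            return chosen
    for gpu in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
        if gpu in available:
            # Keep CPU after the GPU provider so unsupported operators can run on CPU
            return [gpu, "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
```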

## Environment Variables

- `OMP_NUM_THREADS`: Number of threads for CPU processing; applied to the ONNX Runtime session options
- `MODEL_CHECKSUM_DISABLED`: When set, disables model file checksum validation
- `U2NET_HOME`: Custom directory for model storage (default: `~/.u2net`)
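Reading these variables can be sketched with the standard library. The function `read_session_env` is hypothetical and takes the environment as a mapping so it can be exercised without touching `os.environ`. Treating any value of `MODEL_CHECKSUM_DISABLED` (even an empty one) as "disabled" is an assumption of this sketch.

```python
import os
from typing import Mapping, Optional, Tuple


def read_session_env(env: Mapping[str, str]) -> Tuple[Optional[int], bool, str]:
    """Return (thread_count, checksum_disabled, model_home) from an environment mapping."""
    # OMP_NUM_THREADS: CPU thread count; None leaves ONNX Runtime at its own default
    threads = int(env["OMP_NUM_THREADS"]) if "OMP_NUM_THREADS" in env else None
    # Assumption: the mere presence of the variable disables checksum validation
    checksum_disabled = "MODEL_CHECKSUM_DISABLED" in env
    # U2NET_HOME overrides the default model directory ~/.u2net
    home = os.path.expanduser(env.get("U2NET_HOME", os.path.join("~", ".u2net")))
    return threads, checksum_disabled, home
```

With ONNX Runtime, a thread count like this would typically be applied through `SessionOptions.intra_op_num_threads` and `SessionOptions.inter_op_num_threads`, e.g. `read_session_env(os.environ)`.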