# Model Types and Utilities

Enumerations for supported model classes and utility functions for discovering available models and their capabilities. These components help with model selection, validation, and programmatic access to FlagEmbedding's model ecosystem.

## Capabilities

### EmbedderModelClass Enumeration

Enumeration defining the available embedder model architecture classes, used for programmatic model selection and validation.

```python { .api }
from enum import Enum

class EmbedderModelClass(Enum):
    """Enumeration of available embedder model classes."""

    ENCODER_ONLY_BASE = "encoder-only-base"
    """Standard encoder-only models (BERT-like architectures)."""

    ENCODER_ONLY_M3 = "encoder-only-m3"
    """BGE-M3 specialized encoder models with multi-vector support."""

    DECODER_ONLY_BASE = "decoder-only-base"
    """Standard decoder-only models (LLM-like architectures)."""

    DECODER_ONLY_ICL = "decoder-only-icl"
    """In-context learning decoder models."""
```
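
Because each member's value is a plain string, a value read from a config file or command line can be turned into a member by value lookup. The sketch below mirrors the enumeration locally so that it runs without FlagEmbedding installed; `parse_embedder_class` is a hypothetical helper for illustration, not part of the library.

```python
from enum import Enum


# Local mirror of the enumeration above, so the sketch runs standalone.
class EmbedderModelClass(Enum):
    ENCODER_ONLY_BASE = "encoder-only-base"
    ENCODER_ONLY_M3 = "encoder-only-m3"
    DECODER_ONLY_BASE = "decoder-only-base"
    DECODER_ONLY_ICL = "decoder-only-icl"


def parse_embedder_class(raw: str) -> EmbedderModelClass:
    """Hypothetical helper: map a config string to an enum member."""
    try:
        # Enum lookup by value; raises ValueError for unknown strings.
        return EmbedderModelClass(raw.strip().lower())
    except ValueError:
        valid = ", ".join(m.value for m in EmbedderModelClass)
        raise ValueError(
            f"Unknown embedder class {raw!r}; expected one of: {valid}"
        ) from None


print(parse_embedder_class("Encoder-Only-M3"))
```

Normalizing case and whitespace before the lookup is a design choice of this sketch; a stricter caller may prefer to reject anything that is not an exact match.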

### RerankerModelClass Enumeration

Enumeration defining the available reranker model architecture classes, used for programmatic reranker selection and validation.

```python { .api }
from enum import Enum

class RerankerModelClass(Enum):
    """Enumeration of available reranker model classes."""

    ENCODER_ONLY_BASE = "encoder-only-base"
    """Standard encoder-only rerankers (cross-encoder architecture)."""

    DECODER_ONLY_BASE = "decoder-only-base"
    """Standard decoder-only rerankers (LLM-based)."""

    DECODER_ONLY_LAYERWISE = "decoder-only-layerwise"
    """Layer-wise processing decoder rerankers."""

    DECODER_ONLY_LIGHTWEIGHT = "decoder-only-lightweight"
    """Lightweight decoder rerankers for efficiency."""
```
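
One way to use the enumeration for validation is to let a command-line parser do the conversion. The following is a minimal sketch under two assumptions: the enumeration is mirrored locally so the code runs standalone, and the `--reranker-class` flag is a hypothetical name chosen for this example.

```python
import argparse
from enum import Enum


# Local mirror of the enumeration above, so the sketch runs standalone.
class RerankerModelClass(Enum):
    ENCODER_ONLY_BASE = "encoder-only-base"
    DECODER_ONLY_BASE = "decoder-only-base"
    DECODER_ONLY_LAYERWISE = "decoder-only-layerwise"
    DECODER_ONLY_LIGHTWEIGHT = "decoder-only-lightweight"


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Reranker selection sketch")
    parser.add_argument(
        "--reranker-class",  # hypothetical flag name
        type=RerankerModelClass,  # argparse calls RerankerModelClass(value)
        choices=list(RerankerModelClass),
        default=RerankerModelClass.ENCODER_ONLY_BASE,
    )
    return parser


args = build_parser().parse_args(["--reranker-class", "decoder-only-layerwise"])
print(args.reranker_class)
```

An unknown string fails the conversion and argparse exits with a usage error, so the rest of the program only ever sees a valid member.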

### PoolingMethod Enumeration

Enumeration of pooling strategies available for embedders, determining how token representations are combined into sentence embeddings.

```python { .api }
from enum import Enum

class PoolingMethod(Enum):
    """Enumeration of pooling methods for embedders."""

    LAST_TOKEN = "last_token"
    """Use the last token representation (common for decoder-only models)."""

    CLS = "cls"
    """Use the CLS token representation (common for encoder-only models)."""

    MEAN = "mean"
    """Use mean pooling across all tokens."""
```
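
To make the three strategies concrete, here is a small pure-Python sketch of what each one computes for a single sequence. It is not FlagEmbedding's implementation, which works on batched tensors; the `pool` function is hypothetical and assumes right padding, with the CLS token at position 0.

```python
from typing import List, Sequence

Vector = List[float]


def pool(hidden: Sequence[Vector], mask: Sequence[int], method: str) -> Vector:
    """Combine per-token vectors into one sentence vector.

    hidden: one vector per token position.
    mask:   1 for real tokens, 0 for padding (right padding assumed).
    """
    real = [vec for vec, keep in zip(hidden, mask) if keep]
    if not real:
        raise ValueError("mask selects no tokens")
    if method == "cls":
        # First position holds the CLS token in this sketch.
        return list(hidden[0])
    if method == "last_token":
        # Last non-padding position.
        return list(real[-1])
    if method == "mean":
        # Average over non-padding positions only.
        dim = len(real[0])
        return [sum(vec[i] for vec in real) / len(real) for i in range(dim)]
    raise ValueError(f"unknown pooling method: {method!r}")


tokens = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [0.0, 0.0]]  # last row is padding
mask = [1, 1, 1, 0]
for name in ("cls", "last_token", "mean"):
    print(name, pool(tokens, mask, name))
```

The mask matters for `mean` and `last_token`: without it, padding positions would dilute the average or be mistaken for the final token.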

### Model Discovery Utilities

Utility functions for discovering available models and their capabilities programmatically.

```python { .api }
from typing import List

def support_model_list() -> List[str]:
    """
    Get list of all supported model names across all model types.

    Returns:
        List of all supported model names that can be used with
        FlagAutoModel and FlagAutoReranker
    """

def support_native_bge_model_list() -> List[str]:
    """
    Get list of native BGE model names specifically supported.

    Returns:
        List of BGE model names with optimized support
    """
```
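
A plain membership test against the returned list gives no help when a name is slightly misspelled. The sketch below shows one possible approach using the standard library's `difflib`; `suggest_model` is a hypothetical helper, and the hard-coded list stands in for the result of `support_model_list()` so the example runs without FlagEmbedding installed. It also assumes the list holds bare model names without an organisation prefix.

```python
import difflib
from typing import List, Optional


def suggest_model(name: str, supported: List[str]) -> Optional[str]:
    """Hypothetical helper: exact match, else closest supported name, else None."""
    # Drop an organisation prefix such as "BAAI/" before comparing
    # (assumption: the supported list holds bare model names).
    bare = name.rsplit("/", 1)[-1]
    if bare in supported:
        return bare
    close = difflib.get_close_matches(bare, supported, n=1, cutoff=0.8)
    return close[0] if close else None


# Stand-in for the result of support_model_list().
supported = ["bge-large-en-v1.5", "bge-m3", "e5-large-v2"]

print(suggest_model("BAAI/bge-m3", supported))
print(suggest_model("bge-large-en-v15", supported))
print(suggest_model("custom-model", supported))
```

The `cutoff` of 0.8 is an arbitrary choice for this sketch; a lower value suggests more aggressively, at the cost of occasionally proposing an unrelated model.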

## Usage Examples

### Programmatic Model Selection

```python
from FlagEmbedding import EmbedderModelClass, RerankerModelClass, FlagAutoModel

# Use enumeration for type-safe model selection
model_class = EmbedderModelClass.ENCODER_ONLY_BASE

embedder = FlagAutoModel.from_finetuned(
    'bge-large-en-v1.5',
    model_class=model_class,  # Explicit model class
    use_fp16=True
)

# Check model class programmatically
if model_class == EmbedderModelClass.ENCODER_ONLY_M3:
    print("Using BGE-M3 specialized model")
elif model_class == EmbedderModelClass.DECODER_ONLY_BASE:
    print("Using LLM-based embedder")
```

### Model Discovery and Validation

```python
from FlagEmbedding import support_model_list, support_native_bge_model_list

# Get all supported models
all_models = support_model_list()
print(f"Total supported models: {len(all_models)}")

# Get BGE-specific models
bge_models = support_native_bge_model_list()
print(f"Native BGE models: {len(bge_models)}")

# Validate model availability
def is_model_supported(model_name: str) -> bool:
    return model_name in support_model_list()

# Check specific models
test_models = ['bge-large-en-v1.5', 'custom-model', 'e5-large-v2']
for model in test_models:
    status = "✓" if is_model_supported(model) else "✗"
    print(f"{status} {model}")
```

### Dynamic Model Configuration

```python
from FlagEmbedding import EmbedderModelClass, PoolingMethod, FlagAutoModel

# Configuration based on model type
def get_optimal_config(model_class: EmbedderModelClass) -> dict:
    """Get optimal configuration for different model classes."""

    if model_class == EmbedderModelClass.ENCODER_ONLY_BASE:
        return {
            'pooling_method': PoolingMethod.CLS.value,
            'batch_size': 256,
            'use_fp16': True
        }
    elif model_class == EmbedderModelClass.ENCODER_ONLY_M3:
        return {
            'pooling_method': PoolingMethod.CLS.value,
            'batch_size': 128,  # Smaller for M3
            'return_dense': True,
            'return_sparse': False
        }
    elif model_class == EmbedderModelClass.DECODER_ONLY_BASE:
        return {
            'pooling_method': PoolingMethod.LAST_TOKEN.value,
            'batch_size': 64,  # Smaller for LLM
            'use_fp16': True
        }
    else:
        return {}

# Apply configuration dynamically
model_class = EmbedderModelClass.ENCODER_ONLY_BASE
config = get_optimal_config(model_class)

embedder = FlagAutoModel.from_finetuned(
    'bge-large-en-v1.5',
    model_class=model_class,
    **config
)
```

### Model Type Detection

```python
from FlagEmbedding import support_model_list, EmbedderModelClass

def detect_model_type(model_name: str) -> EmbedderModelClass:
    """Detect model type based on model name patterns."""

    model_name_lower = model_name.lower()

    if 'bge-m3' in model_name_lower:
        return EmbedderModelClass.ENCODER_ONLY_M3
    elif any(llm_name in model_name_lower for llm_name in ['mistral', 'qwen', 'gemma']):
        return EmbedderModelClass.DECODER_ONLY_BASE
    elif 'icl' in model_name_lower:
        return EmbedderModelClass.DECODER_ONLY_ICL
    else:
        return EmbedderModelClass.ENCODER_ONLY_BASE

# Test model type detection
test_models = [
    'bge-large-en-v1.5',
    'bge-m3',
    'e5-mistral-7b-instruct',
    'bge-en-icl'
]

for model in test_models:
    detected_type = detect_model_type(model)
    print(f"{model} -> {detected_type.value}")
```

### Reranker Type Selection

```python
from FlagEmbedding import RerankerModelClass, FlagAutoReranker

def select_reranker_by_requirements(
    speed_priority: bool = False,
    accuracy_priority: bool = False,
    resource_constrained: bool = False
) -> RerankerModelClass:
    """Select optimal reranker class based on requirements."""

    if resource_constrained:
        return RerankerModelClass.DECODER_ONLY_LIGHTWEIGHT
    elif speed_priority:
        return RerankerModelClass.ENCODER_ONLY_BASE
    elif accuracy_priority:
        return RerankerModelClass.DECODER_ONLY_BASE
    else:
        return RerankerModelClass.DECODER_ONLY_LAYERWISE  # Balanced

# Example usage scenarios
scenarios = [
    {'speed_priority': True},
    {'accuracy_priority': True},
    {'resource_constrained': True},
    {}  # Default balanced
]

for i, scenario in enumerate(scenarios):
    reranker_class = select_reranker_by_requirements(**scenario)
    print(f"Scenario {i+1}: {reranker_class.value}")
```

### Model Compatibility Checking

```python
from FlagEmbedding import EmbedderModelClass, PoolingMethod

def check_model_compatibility(
    model_class: EmbedderModelClass,
    pooling_method: str
) -> bool:
    """Check if pooling method is compatible with model class."""

    compatible_combinations = {
        EmbedderModelClass.ENCODER_ONLY_BASE: [PoolingMethod.CLS.value, PoolingMethod.MEAN.value],
        EmbedderModelClass.ENCODER_ONLY_M3: [PoolingMethod.CLS.value],
        EmbedderModelClass.DECODER_ONLY_BASE: [PoolingMethod.LAST_TOKEN.value],
        EmbedderModelClass.DECODER_ONLY_ICL: [PoolingMethod.LAST_TOKEN.value]
    }

    return pooling_method in compatible_combinations.get(model_class, [])

# Test compatibility
test_cases = [
    (EmbedderModelClass.ENCODER_ONLY_BASE, PoolingMethod.CLS.value),
    (EmbedderModelClass.ENCODER_ONLY_BASE, PoolingMethod.LAST_TOKEN.value),  # Invalid
    (EmbedderModelClass.DECODER_ONLY_BASE, PoolingMethod.LAST_TOKEN.value),
    (EmbedderModelClass.ENCODER_ONLY_M3, PoolingMethod.MEAN.value)  # Invalid
]

for model_class, pooling in test_cases:
    compatible = check_model_compatibility(model_class, pooling)
    status = "✓" if compatible else "✗"
    print(f"{status} {model_class.value} + {pooling}")
```

### Advanced Model Registry

```python
from FlagEmbedding import support_model_list, EmbedderModelClass, RerankerModelClass
from typing import Dict, List, Optional

class ModelRegistry:
    """Advanced model registry with categorization and metadata."""

    def __init__(self):
        self.supported_models = support_model_list()
        self._build_registry()

    def _build_registry(self):
        """Build categorized model registry."""
        self.registry = {
            'embedders': {
                'bge': [m for m in self.supported_models if 'bge' in m.lower() and 'reranker' not in m.lower()],
                'e5': [m for m in self.supported_models if 'e5' in m.lower()],
                'gte': [m for m in self.supported_models if 'gte' in m.lower()]
            },
            'rerankers': [m for m in self.supported_models if 'reranker' in m.lower()]
        }

    def get_models_by_family(self, family: str) -> List[str]:
        """Get models by family (bge, e5, gte)."""
        return self.registry['embedders'].get(family, [])

    def get_rerankers(self) -> List[str]:
        """Get all reranker models."""
        return self.registry['rerankers']

    def recommend_model(self, task_type: str, performance_tier: str) -> Optional[str]:
        """Recommend model based on task and performance requirements.

        Returns None when no recommendation exists for the combination.
        """
        recommendations = {
            ('embedding', 'high'): 'bge-large-en-v1.5',
            ('embedding', 'medium'): 'bge-base-en-v1.5',
            ('embedding', 'fast'): 'bge-small-en-v1.5',
            ('reranking', 'high'): 'bge-reranker-large',
            ('reranking', 'medium'): 'bge-reranker-base',
            ('reranking', 'fast'): 'bge-reranker-v2.5-gemma2-lightweight'
        }
        return recommendations.get((task_type, performance_tier))

# Usage
registry = ModelRegistry()

print("BGE Models:", registry.get_models_by_family('bge'))
print("E5 Models:", registry.get_models_by_family('e5'))
print("Rerankers:", registry.get_rerankers())

# Get recommendations
embedding_model = registry.recommend_model('embedding', 'high')
reranking_model = registry.recommend_model('reranking', 'medium')
print(f"Recommended embedding model: {embedding_model}")
print(f"Recommended reranking model: {reranking_model}")
```

## Model Categories

### BGE Models (BAAI General Embedding)
- **English**: bge-large-en-v1.5, bge-base-en-v1.5, bge-small-en-v1.5
- **Chinese**: bge-large-zh-v1.5, bge-base-zh-v1.5, bge-small-zh-v1.5
- **Multilingual**: bge-multilingual-gemma2
- **Specialized**: bge-m3 (multi-vector), bge-en-icl (in-context learning)

### E5 Models (EmbEddings from bidirEctional Encoder rEpresentations)
- **Standard**: e5-large-v2, e5-base-v2, e5-small-v2
- **Multilingual**: multilingual-e5-large, multilingual-e5-base, multilingual-e5-small
- **LLM-based**: e5-mistral-7b-instruct

### GTE Models (General Text Embeddings)
- **English**: gte-large-en-v1.5, gte-base-en-v1.5
- **Chinese**: gte-large-zh, gte-base-zh, gte-small-zh
- **LLM-based**: gte-Qwen2-7B-instruct, gte-Qwen2-1.5B-instruct

### Reranker Models
- **Cross-encoder**: bge-reranker-base, bge-reranker-large, bge-reranker-v2-m3
- **LLM-based**: bge-reranker-v2-gemma
- **Specialized**: bge-reranker-v2-minicpm-layerwise, bge-reranker-v2.5-gemma2-lightweight

## Types

```python { .api }
from enum import Enum
from typing import Callable, List, Literal

# Model class enumerations
EmbedderClass = EmbedderModelClass
RerankerClass = RerankerModelClass
PoolingStrategy = PoolingMethod

# Model selection types
ModelName = str
ModelFamily = Literal["bge", "e5", "gte"]
TaskType = Literal["embedding", "reranking"]
PerformanceTier = Literal["high", "medium", "fast"]

# Utility function types
ModelList = List[str]
ModelValidator = Callable[[str], bool]
```
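
`Literal` aliases such as `TaskType` and `PerformanceTier` are enforced only by static type checkers; at run time any string passes. If a run-time check is wanted, the allowed options can be read back from the alias with `typing.get_args`. The sketch below redefines two of the aliases locally, and `validate_choice` is a hypothetical helper for illustration.

```python
from typing import Literal, get_args

# Local copies of two aliases from the Types block above.
TaskType = Literal["embedding", "reranking"]
PerformanceTier = Literal["high", "medium", "fast"]


def validate_choice(value: str, literal_type) -> str:
    """Hypothetical helper: check a string against the options of a Literal alias."""
    allowed = get_args(literal_type)  # e.g. ("high", "medium", "fast")
    if value not in allowed:
        raise ValueError(f"{value!r} is not one of {allowed}")
    return value


print(validate_choice("fast", PerformanceTier))
print(validate_choice("reranking", TaskType))
```

Deriving the allowed values from the alias keeps the static type and the run-time check in one place, so adding a new tier requires a single edit.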