# Base Classes and Interfaces

Abstract base classes that define the core interface contracts for embedders and rerankers. These classes provide multi-device support and consistent API patterns, and they form the foundation for all concrete implementations in FlagEmbedding.

## Capabilities

### AbsEmbedder (Abstract Embedder Base)

Base class for all embedding models. It provides a unified interface for encoding text into vector representations, distributes work across multiple devices, and handles batching the same way across model architectures.

```python { .api }
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch


class AbsEmbedder(ABC):
    def __init__(
        self,
        model_name_or_path: str,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        query_instruction_format: str = "{}{}",
        devices: Optional[Union[str, List[str]]] = None,
        batch_size: int = 256,
        query_max_length: int = 512,
        passage_max_length: int = 512,
        convert_to_numpy: bool = True,
        **kwargs
    ):
        """
        Initialize abstract embedder base class.

        Args:
            model_name_or_path: Local model path or Hugging Face Hub model name
            normalize_embeddings: Whether to L2-normalize output embeddings
            use_fp16: Use half precision for inference
            query_instruction_for_retrieval: Instruction prepended to queries
            query_instruction_format: Format string combining instruction and query ("{}{}" concatenates them)
            devices: Device or list of devices; more than one enables multi-process inference
            batch_size: Default batch size for encoding
            query_max_length: Maximum query token length
            passage_max_length: Maximum passage token length
            convert_to_numpy: Return numpy arrays instead of torch tensors
            **kwargs: Additional model-specific parameters
        """

    def encode_queries(
        self,
        queries: Union[str, List[str]],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Encode queries for retrieval tasks.

        Args:
            queries: Single query string or list of query strings
            batch_size: Batch size for processing (overrides default)
            max_length: Maximum sequence length (overrides query_max_length)
            convert_to_numpy: Convert output to numpy (overrides default)
            **kwargs: Additional encoding parameters

        Returns:
            Query embeddings as tensor or numpy array
        """

    def encode_corpus(
        self,
        corpus: Union[str, List[str]],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Encode corpus documents for retrieval tasks.

        Args:
            corpus: Single document string or list of document strings
            batch_size: Batch size for processing (overrides default)
            max_length: Maximum sequence length (overrides passage_max_length)
            convert_to_numpy: Convert output to numpy (overrides default)
            **kwargs: Additional encoding parameters

        Returns:
            Corpus embeddings as tensor or numpy array
        """

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        instruction: Optional[str] = None,
        instruction_format: Optional[str] = None,
        **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        General-purpose encoding method for any text.

        Args:
            sentences: Single sentence or list of sentences to encode
            batch_size: Batch size for processing
            max_length: Maximum sequence length
            convert_to_numpy: Convert output to numpy
            instruction: Instruction to prepend to sentences
            instruction_format: Format string combining instruction and sentence
            **kwargs: Additional encoding parameters

        Returns:
            Text embeddings as tensor or numpy array
        """

    @abstractmethod
    def encode_single_device(
        self,
        sentences: List[str],
        batch_size: int = 256,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Optional[str] = None,
        **kwargs
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Abstract method for single-device encoding (implemented by subclasses).

        Args:
            sentences: List of sentences to encode
            batch_size: Batch size for processing
            max_length: Maximum sequence length
            convert_to_numpy: Convert output to numpy
            device: Specific device for processing
            **kwargs: Additional encoding parameters

        Returns:
            Embeddings from single device
        """

    def start_multi_process_pool(
        self,
        process_target_func: Callable
    ) -> Dict[str, Any]:
        """
        Start one worker process per target device.

        Args:
            process_target_func: Worker function run in each process

        Returns:
            Pool handle: a dict holding the worker processes and their input/output queues
        """

    @staticmethod
    def stop_multi_process_pool(pool: Dict[str, Any]) -> None:
        """
        Stop multi-process pool and clean up resources.

        Args:
            pool: Process pool to terminate
        """
```
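
To make the division of labour concrete, here is a minimal, dependency-free sketch of the same template-method structure. `ToyAbsEmbedder` and `CharCountEmbedder` are hypothetical names, the "model" is a character-count vector rather than a transformer, and the real base class additionally handles devices, tensors, and NumPy conversion. It mirrors what the interface above describes: `encode_queries` prepends the query instruction through the format string and falls back to `query_max_length`, call-level `None` arguments fall back to constructor defaults, and a subclass only has to implement `encode_single_device`.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Union

Vector = List[float]


class ToyAbsEmbedder(ABC):
    """Hypothetical, simplified stand-in for AbsEmbedder: no devices, tensors, or fp16."""

    def __init__(self, normalize_embeddings: bool = True,
                 query_instruction_for_retrieval: Optional[str] = None,
                 query_instruction_format: str = "{}{}", batch_size: int = 256,
                 query_max_length: int = 512, passage_max_length: int = 512):
        self.normalize_embeddings = normalize_embeddings
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.query_instruction_format = query_instruction_format
        self.batch_size = batch_size
        self.query_max_length = query_max_length
        self.passage_max_length = passage_max_length

    def encode_queries(self, queries: Union[str, List[str]], batch_size: Optional[int] = None,
                       max_length: Optional[int] = None):
        # Queries get the retrieval instruction and fall back to query_max_length
        if max_length is None:
            max_length = self.query_max_length
        return self.encode(queries, batch_size, max_length,
                           instruction=self.query_instruction_for_retrieval,
                           instruction_format=self.query_instruction_format)

    def encode_corpus(self, corpus: Union[str, List[str]], batch_size: Optional[int] = None,
                      max_length: Optional[int] = None):
        # Passages are encoded without an instruction and fall back to passage_max_length
        if max_length is None:
            max_length = self.passage_max_length
        return self.encode(corpus, batch_size, max_length)

    def encode(self, sentences: Union[str, List[str]], batch_size: Optional[int] = None,
               max_length: Optional[int] = None, instruction: Optional[str] = None,
               instruction_format: Optional[str] = None):
        # Call-level None values fall back to the constructor defaults
        if batch_size is None:
            batch_size = self.batch_size
        if max_length is None:
            max_length = self.passage_max_length
        single = isinstance(sentences, str)
        texts = [sentences] if single else list(sentences)
        if instruction is not None:
            fmt = instruction_format or "{}{}"
            texts = [fmt.format(instruction, t) for t in texts]
        vectors = self.encode_single_device(texts, batch_size, max_length)
        return vectors[0] if single else vectors

    @abstractmethod
    def encode_single_device(self, sentences: List[str], batch_size: int,
                             max_length: int) -> List[Vector]:
        """Subclasses turn a list of strings into one vector per string."""


class CharCountEmbedder(ToyAbsEmbedder):
    """Hypothetical subclass: counts of 'a'..'e' in the first max_length characters."""

    def encode_single_device(self, sentences, batch_size, max_length):
        vectors: List[Vector] = []
        for start in range(0, len(sentences), batch_size):  # walk the input in batches
            for text in sentences[start:start + batch_size]:
                # max_length truncates characters here, standing in for token truncation
                vec = [float(text[:max_length].count(ch)) for ch in "abcde"]
                if self.normalize_embeddings:
                    norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero
                    vec = [v / norm for v in vec]
                vectors.append(vec)
        return vectors


embedder = CharCountEmbedder(normalize_embeddings=False, query_instruction_for_retrieval="ab: ")
print(embedder.encode_queries("cab"))   # encodes "ab: cab" -> [2.0, 2.0, 1.0, 0.0, 0.0]
print(embedder.encode_corpus(["cab"]))  # no instruction -> [[1.0, 1.0, 1.0, 0.0, 0.0]]
```

The shared logic (instruction handling, defaults, batching) lives once in the base class, which is why every concrete model exposes identical `encode_*` methods.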

### AbsReranker (Abstract Reranker Base)

Base class for all reranking models. It provides a unified interface for scoring query-document pairs, supports multi-device processing, and lets queries and passages carry separate instructions, each with its own format string.

```python { .api }
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import numpy as np


class AbsReranker(ABC):
    def __init__(
        self,
        model_name_or_path: str,
        use_fp16: bool = False,
        query_instruction_for_rerank: Optional[str] = None,
        query_instruction_format: str = "{}{}",
        passage_instruction_for_rerank: Optional[str] = None,
        passage_instruction_format: str = "{}{}",
        devices: Optional[Union[str, List[str]]] = None,
        batch_size: int = 128,
        query_max_length: Optional[int] = None,
        max_length: int = 512,
        normalize: bool = False,
        **kwargs
    ):
        """
        Initialize abstract reranker base class.

        Args:
            model_name_or_path: Local model path or Hugging Face Hub model name
            use_fp16: Use half precision for inference
            query_instruction_for_rerank: Instruction prepended to queries
            query_instruction_format: Format string combining instruction and query
            passage_instruction_for_rerank: Instruction prepended to passages
            passage_instruction_format: Format string combining instruction and passage
            devices: Device or list of devices; more than one enables multi-process inference
            batch_size: Default batch size for scoring
            query_max_length: Maximum query token length
            max_length: Maximum total (query + passage) sequence length
            normalize: Whether to map raw scores into the range (0, 1)
            **kwargs: Additional model-specific parameters
        """

    def compute_score(
        self,
        sentence_pairs: List[Tuple[str, str]],
        **kwargs
    ) -> np.ndarray:
        """
        Compute relevance scores for query-document pairs.

        Args:
            sentence_pairs: List of (query, document) tuples
            **kwargs: Additional scoring parameters

        Returns:
            Array of relevance scores (higher = more relevant)
        """

    @abstractmethod
    def compute_score_single_gpu(
        self,
        sentence_pairs: List[Tuple[str, str]],
        batch_size: int = 256,
        query_max_length: Optional[int] = None,
        max_length: int = 512,
        normalize: bool = False,
        device: Optional[str] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Abstract method for single-GPU scoring (implemented by subclasses).

        Args:
            sentence_pairs: List of (query, document) tuples
            batch_size: Batch size for processing
            query_max_length: Maximum query token length
            max_length: Maximum total sequence length
            normalize: Whether to normalize scores
            device: Specific device for processing
            **kwargs: Additional scoring parameters

        Returns:
            Relevance scores from single GPU
        """
```
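
A similarly minimal sketch of the reranker side, again with hypothetical names (`ToyAbsReranker`, `OverlapReranker`) and a toy scorer (shared-word count) in place of a cross-encoder. It mirrors two behaviours described by the interface above: query and passage instructions are applied separately through their own format strings before scoring, and `normalize=True` maps raw scores into (0, 1). Using a sigmoid for that mapping is an assumption of this sketch.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

Pair = Tuple[str, str]


class ToyAbsReranker(ABC):
    """Hypothetical, simplified stand-in for AbsReranker (no model, no devices)."""

    def __init__(self, query_instruction_for_rerank: Optional[str] = None,
                 query_instruction_format: str = "{}{}",
                 passage_instruction_for_rerank: Optional[str] = None,
                 passage_instruction_format: str = "{}{}",
                 normalize: bool = False):
        self.query_instruction_for_rerank = query_instruction_for_rerank
        self.query_instruction_format = query_instruction_format
        self.passage_instruction_for_rerank = passage_instruction_for_rerank
        self.passage_instruction_format = passage_instruction_format
        self.normalize = normalize

    def apply_instructions(self, pair: Pair) -> Pair:
        # Query and passage each get their own instruction and format string
        query, passage = pair
        if self.query_instruction_for_rerank is not None:
            query = self.query_instruction_format.format(self.query_instruction_for_rerank, query)
        if self.passage_instruction_for_rerank is not None:
            passage = self.passage_instruction_format.format(self.passage_instruction_for_rerank, passage)
        return query, passage

    def compute_score(self, sentence_pairs: List[Pair], normalize: Optional[bool] = None) -> List[float]:
        if normalize is None:
            normalize = self.normalize  # call-level value overrides the constructor default
        pairs = [self.apply_instructions(p) for p in sentence_pairs]
        return self.compute_score_single_gpu(pairs, normalize=normalize)

    @abstractmethod
    def compute_score_single_gpu(self, sentence_pairs: List[Pair], normalize: bool) -> List[float]:
        """Subclasses score already-formatted pairs."""


class OverlapReranker(ToyAbsReranker):
    """Hypothetical subclass: raw score = number of distinct shared lowercase words."""

    def compute_score_single_gpu(self, sentence_pairs, normalize):
        scores = []
        for query, passage in sentence_pairs:
            raw = float(len(set(query.lower().split()) & set(passage.lower().split())))
            # Assumed normalization: squash raw scores into (0, 1) with a sigmoid
            scores.append(1.0 / (1.0 + math.exp(-raw)) if normalize else raw)
        return scores


reranker = OverlapReranker()
print(reranker.compute_score([("machine learning", "Machine learning is a branch of AI")]))  # [2.0]
print(reranker.compute_score([("cats", "dogs")], normalize=True))  # [0.5]
```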

## Usage Examples

### Understanding the Base Class Interface

```python
from FlagEmbedding import AbsEmbedder, AbsReranker, FlagModel, FlagReranker

# All concrete embedders inherit from AbsEmbedder
embedder = FlagModel('BAAI/bge-base-en-v1.5')
assert isinstance(embedder, AbsEmbedder)

# All concrete rerankers inherit from AbsReranker
reranker = FlagReranker('BAAI/bge-reranker-base')
assert isinstance(reranker, AbsReranker)

# Base class methods are available on all implementations
queries = ["What is machine learning?"]
embeddings = embedder.encode_queries(queries)  # AbsEmbedder method

pairs = [("query", "document")]
scores = reranker.compute_score(pairs)  # AbsReranker method
```

### Multi-Device Processing with Base Classes

```python
from FlagEmbedding import FlagModel


def main():
    # Base class handles multi-device distribution automatically
    embedder = FlagModel(
        'BAAI/bge-large-en-v1.5',
        devices=['cuda:0', 'cuda:1', 'cuda:2'],  # Multiple GPUs
        batch_size=256
    )

    # Large corpus processing: the base class splits the work across devices
    large_corpus = [f"Document {i}" for i in range(50000)]
    embeddings = embedder.encode_corpus(
        large_corpus,
        batch_size=512,  # Override default batch size
        convert_to_numpy=True
    )

    print(f"Processed {len(large_corpus)} documents across {len(embedder.target_devices)} devices")


# Guard the entry point, since multi-device encoding starts worker processes
if __name__ == "__main__":
    main()
```
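
The base class hides how a large input list is shared out. One plausible scheme, shown here as a hypothetical `split_across_devices` helper, is to cut the list into one contiguous chunk per device (ceiling division), encode the chunks in parallel, and concatenate the results in chunk order so that output row *i* still corresponds to input *i*. This is a sketch of the idea, not FlagEmbedding's internal code; threads stand in for worker processes and `fake_encode` stands in for per-device model inference.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, Tuple

EncodeFn = Callable[[List[str], str], List[List[float]]]  # (texts, device) -> vectors


def split_across_devices(items: Sequence[str], devices: Sequence[str]) -> List[Tuple[str, List[str]]]:
    """Assign one contiguous chunk per device; trailing devices may get a shorter chunk or none."""
    chunk_size = math.ceil(len(items) / len(devices)) if items else 0
    chunks = []
    for i, device in enumerate(devices):
        chunk = list(items[i * chunk_size:(i + 1) * chunk_size])
        if chunk:
            chunks.append((device, chunk))
    return chunks


def encode_on_devices(items: Sequence[str], devices: Sequence[str], encode_fn: EncodeFn) -> List[List[float]]:
    """Encode chunks concurrently and concatenate the results in input order."""
    chunks = split_across_devices(items, devices)
    with ThreadPoolExecutor(max_workers=len(chunks) or 1) as executor:
        # executor.map yields results in the order of its inputs, so they line up with chunks
        per_chunk = list(executor.map(lambda dc: encode_fn(dc[1], dc[0]), chunks))
    return [vec for chunk_result in per_chunk for vec in chunk_result]


def fake_encode(texts: List[str], device: str) -> List[List[float]]:
    # Hypothetical stand-in for a per-device model: one-dimensional "embedding" = text length
    return [[float(len(t))] for t in texts]


docs = [f"Document {i}" for i in range(10)]
print([len(chunk) for _, chunk in split_across_devices(docs, ["cuda:0", "cuda:1", "cuda:2"])])  # [4, 4, 2]
```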

### Custom Instruction Handling

```python
from FlagEmbedding import FlagModel, FlagReranker

# Embedder with custom query instructions
embedder = FlagModel(
    'BAAI/bge-base-en-v1.5',
    query_instruction_for_retrieval="Search for: ",
    query_instruction_format="{}{}"  # instruction first, then the query
)

# Reranker with separate query and passage instructions
reranker = FlagReranker(
    'BAAI/bge-reranker-base',
    query_instruction_for_rerank="Query: ",
    passage_instruction_for_rerank="Document: ",
    query_instruction_format="{}{}",
    passage_instruction_format="{}{}"
)

# Instructions are applied automatically by the base class methods
queries = ["machine learning concepts"]
embeddings = embedder.encode_queries(queries)  # encodes "Search for: machine learning concepts"

pairs = [("AI research", "Machine learning is a branch of AI")]
scores = reranker.compute_score(pairs)  # scores ("Query: AI research", "Document: Machine learning is a branch of AI")
```
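
A minimal sketch of how an instruction and a text can be combined, assuming (as the `"{}{}"` default implies) that the instruction fills the first `{}` and the text the second via Python's `str.format`. With the default this is plain concatenation; a custom format string changes the layout. `apply_instruction` is a hypothetical helper, and the `"Instruct: {}\nQuery: {}"` template is only an illustration, not a required value.

```python
from typing import List, Optional


def apply_instruction(texts: List[str], instruction: Optional[str],
                      instruction_format: str = "{}{}") -> List[str]:
    """Hypothetical helper: combine an instruction with each text via a format string."""
    if instruction is None:
        return list(texts)  # no instruction: texts pass through unchanged
    return [instruction_format.format(instruction, text) for text in texts]


queries = ["machine learning concepts"]

# Default format: plain concatenation
print(apply_instruction(queries, "Search for: "))
# ['Search for: machine learning concepts']

# Illustrative custom template with a line break between instruction and query
print(apply_instruction(queries, "Find related passages", "Instruct: {}\nQuery: {}"))
```

Note that the default format adds no separator, so a trailing space in the instruction (as in `"Search for: "`) matters.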

### Flexible Encoding Methods

```python
from FlagEmbedding import FlagModel

embedder = FlagModel('BAAI/bge-base-en-v1.5')

# Different encoding methods for different use cases
queries = ["How do neural networks work?"]
documents = ["Neural networks are computing systems inspired by biology"]
general_text = ["Some general text to embed"]

# Query/corpus methods apply their own defaults (instruction, max length); per-call arguments override them
query_embeddings = embedder.encode_queries(queries, max_length=256)
doc_embeddings = embedder.encode_corpus(documents, max_length=512)

# General-purpose method with a custom instruction
general_embeddings = embedder.encode(
    general_text,
    instruction="Encode this text: ",
    instruction_format="{}{}",
    max_length=384
)
```
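
Each call-level argument left as `None` falls back to a default: `batch_size` and `convert_to_numpy` to the constructor values, and `max_length` to `query_max_length` for `encode_queries` or `passage_max_length` for `encode_corpus`. The hypothetical `resolve_settings` function below sketches that precedence rule; treating plain `encode` like `encode_corpus` (falling back to `passage_max_length`) is an assumption of this sketch.

```python
from typing import Any, Dict, Optional

# Hypothetical constructor defaults, matching the AbsEmbedder signature above
DEFAULTS: Dict[str, Any] = {
    "batch_size": 256,
    "query_max_length": 512,
    "passage_max_length": 512,
    "convert_to_numpy": True,
}


def resolve_settings(method: str, batch_size: Optional[int] = None,
                     max_length: Optional[int] = None,
                     convert_to_numpy: Optional[bool] = None,
                     defaults: Dict[str, Any] = DEFAULTS) -> Dict[str, Any]:
    """Return the effective settings for an encode_* call; explicit values win over defaults."""
    length_key = "query_max_length" if method == "encode_queries" else "passage_max_length"
    # Compare against None rather than using `or`, so convert_to_numpy=False is respected
    return {
        "batch_size": batch_size if batch_size is not None else defaults["batch_size"],
        "max_length": max_length if max_length is not None else defaults[length_key],
        "convert_to_numpy": convert_to_numpy if convert_to_numpy is not None else defaults["convert_to_numpy"],
    }


print(resolve_settings("encode_queries", max_length=256))
# {'batch_size': 256, 'max_length': 256, 'convert_to_numpy': True}
```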

### Process Pool Management

In current releases, `encode` (and therefore `encode_queries` and `encode_corpus`) starts and reuses a worker pool on its own whenever more than one device is configured, so manual pool management is rarely needed. If you do manage the pool yourself, note that it is a dict of worker processes and queues rather than a `multiprocessing.Pool`, so it has no `map` method. The worker function and `encode_multi_process` used below are FlagEmbedding internals (the leading underscore marks the worker as private) and may change between versions.

```python
from FlagEmbedding import AbsEmbedder, FlagModel


def main():
    # One worker process is started per configured device (two GPUs assumed here)
    embedder = FlagModel('BAAI/bge-base-en-v1.5', devices=['cuda:0', 'cuda:1'])

    # Large dataset to process
    large_dataset = [f"Document {i}" for i in range(100000)]

    # Start the pool with the base class's per-device encoding worker
    pool = embedder.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
    try:
        # Splits the inputs into one chunk per worker and returns embeddings in input order
        all_embeddings = embedder.encode_multi_process(
            large_dataset, pool, batch_size=256, convert_to_numpy=True
        )
        print(all_embeddings.shape)
    finally:
        # Clean up pool using base class static method
        AbsEmbedder.stop_multi_process_pool(pool)


# Guard the entry point so worker processes can import this module safely
if __name__ == "__main__":
    main()
```
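
The exact layout of the dict returned by `start_multi_process_pool` is version-specific. A common design for such pools (sentence-transformers uses it, for example) is a shared input queue of `(chunk_id, texts)` jobs, one worker per device, and an output queue of `(chunk_id, result)` pairs that the caller re-sorts by `chunk_id`. The sketch below reproduces that protocol with threads and stdlib queues so it runs anywhere; every name in it is hypothetical, and real workers would be separate processes that each load the model onto their own device.

```python
import queue
import threading
from typing import Any, Callable, Dict, List

Encoder = Callable[[List[str], str], List[List[float]]]  # (texts, device) -> vectors
_STOP = None  # sentinel that tells a worker to exit


def _worker(device: str, encode_fn: Encoder, inputs: queue.Queue, outputs: queue.Queue) -> None:
    # Pull (chunk_id, texts) jobs until the stop sentinel arrives
    while True:
        job = inputs.get()
        if job is _STOP:
            break
        chunk_id, texts = job
        outputs.put((chunk_id, encode_fn(texts, device)))


def start_pool(devices: List[str], encode_fn: Encoder) -> Dict[str, Any]:
    """Start one worker per device, all sharing one input and one output queue."""
    inputs: queue.Queue = queue.Queue()
    outputs: queue.Queue = queue.Queue()
    workers = [threading.Thread(target=_worker, args=(d, encode_fn, inputs, outputs), daemon=True)
               for d in devices]
    for w in workers:
        w.start()
    return {"input": inputs, "output": outputs, "workers": workers}


def stop_pool(pool: Dict[str, Any]) -> None:
    """Send one stop sentinel per worker, then wait for all of them to exit."""
    for _ in pool["workers"]:
        pool["input"].put(_STOP)
    for w in pool["workers"]:
        w.join()


def encode_with_pool(pool: Dict[str, Any], texts: List[str], chunk_size: int) -> List[List[float]]:
    chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
    for chunk_id, chunk in enumerate(chunks):
        pool["input"].put((chunk_id, chunk))
    # Results arrive in completion order; sorting by chunk_id restores input order
    results = sorted((pool["output"].get() for _ in chunks), key=lambda r: r[0])
    return [vec for _, vectors in results for vec in vectors]


# Usage with a fake per-device encoder (embedding = text length)
pool = start_pool(["cuda:0", "cuda:1"], lambda texts, device: [[float(len(t))] for t in texts])
try:
    print(encode_with_pool(pool, ["a", "bb", "ccc"], chunk_size=2))  # [[1.0], [2.0], [3.0]]
finally:
    stop_pool(pool)
```

Keeping the pool alive between calls, as this design allows, avoids reloading the model on every device for each batch of inputs.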

### Type Checking and Interface Validation

```python
from FlagEmbedding import AbsEmbedder, AbsReranker, FlagAutoModel, FlagAutoReranker

# Factory methods return concrete subclasses of the base classes
embedder = FlagAutoModel.from_finetuned('BAAI/bge-base-en-v1.5')
reranker = FlagAutoReranker.from_finetuned('BAAI/bge-reranker-base')

# Type checking
if isinstance(embedder, AbsEmbedder):
    # Can use all embedder interface methods
    queries = ["test query"]
    embeddings = embedder.encode_queries(queries)
    print(f"Embedding shape: {embeddings.shape}")

if isinstance(reranker, AbsReranker):
    # Can use all reranker interface methods
    pairs = [("query", "document")]
    scores = reranker.compute_score(pairs)
    print(f"Relevance score: {scores[0]}")
```

### Error Handling in Base Classes

The base classes do not document explicit argument validation, so an invalid value may only be rejected deep inside encoding (for example by PyTorch), and the exception type depends on where that happens.

```python
from FlagEmbedding import FlagModel

try:
    # Invalid device string: expected to fail once encoding moves the model to the device
    embedder = FlagModel('BAAI/bge-base-en-v1.5', devices=['invalid:0'])
    embedder.encode_queries(["test"])
except RuntimeError as e:
    print(f"Device error: {e}")

try:
    # Invalid batch size: the exact exception type is not documented
    embedder = FlagModel('BAAI/bge-base-en-v1.5')
    embeddings = embedder.encode_queries(["test"], batch_size=-1)
except (ValueError, RuntimeError) as e:
    print(f"Invalid argument: {e}")
```
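
Because the failure point for a bad argument is not documented, it can be clearer to validate arguments before calling the model. A minimal sketch, assuming only the parameter names from the `AbsEmbedder` signature; `check_encode_args` and the allow-list of device types are hypothetical choices, not FlagEmbedding behaviour.

```python
import re
from typing import List, Optional, Union

# Hypothetical allow-list of device strings; adjust to the backends you actually use
_DEVICE_PATTERN = re.compile(r"^(cpu|mps|(cuda|npu)(:\d+)?)$")


def check_encode_args(batch_size: Optional[int] = None,
                      max_length: Optional[int] = None,
                      devices: Optional[Union[str, List[str]]] = None) -> None:
    """Raise ValueError early for arguments that would otherwise fail deep inside encoding."""
    if batch_size is not None and (not isinstance(batch_size, int) or batch_size <= 0):
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if max_length is not None and (not isinstance(max_length, int) or max_length <= 0):
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
    if devices is not None:
        for device in [devices] if isinstance(devices, str) else devices:
            if not _DEVICE_PATTERN.match(device):
                raise ValueError(f"unrecognised device string: {device!r}")


try:
    check_encode_args(batch_size=-1)
except ValueError as e:
    print(e)  # batch_size must be a positive integer, got -1
```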

## Base Class Benefits

### Consistent Interface
- All embedders expose the same methods (`encode_queries`, `encode_corpus`, `encode`) regardless of the underlying architecture
- All rerankers share the same scoring entry point (`compute_score`)
- Constructor defaults and per-call overrides are handled the same way across implementations

### Multi-Device Support
- When several devices are configured, inputs are split across one worker process per device
- The same calls work unchanged on a single device or on many
- Device selection is resolved in the base class from the `devices` argument

### Flexible Configuration
- Instructions are combined with inputs through format strings (`query_instruction_format`, `passage_instruction_format`)
- Batch size, maximum lengths, and (for embedders) output type set at construction can be overridden per call

### Extensibility
- Clear interface contracts for new implementations
- Subclasses implement a single abstract method (`encode_single_device` or `compute_score_single_gpu`); the base class supplies instruction handling and multi-device dispatch
- Consistent behavior across model types
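
The extensibility guarantees above rest on Python's `abc` machinery: a subclass that does not override every `@abstractmethod` cannot be instantiated at all, so a missing single-device implementation is caught at construction time rather than on first use. The sketch below demonstrates this with hypothetical classes (`SketchEmbedder`, `IncompleteEmbedder`, `LengthEmbedder`) that only imitate the shape of `AbsEmbedder`.

```python
from abc import ABC, abstractmethod
from typing import List


class SketchEmbedder(ABC):
    """Hypothetical base: encode() is shared, encode_single_device() is the hook."""

    def encode(self, sentences: List[str]) -> List[List[float]]:
        return self.encode_single_device(sentences)

    @abstractmethod
    def encode_single_device(self, sentences: List[str]) -> List[List[float]]:
        ...


class IncompleteEmbedder(SketchEmbedder):
    pass  # forgot to implement encode_single_device


class LengthEmbedder(SketchEmbedder):
    def encode_single_device(self, sentences: List[str]) -> List[List[float]]:
        return [[float(len(s))] for s in sentences]


try:
    IncompleteEmbedder()
except TypeError as err:
    print(f"Rejected: {err}")  # the message names the missing abstract method

print(LengthEmbedder().encode(["abc", "de"]))  # [[3.0], [2.0]]
```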

## Types

```python { .api }
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from FlagEmbedding import AbsEmbedder, AbsReranker

# Base class types
EmbedderInput = Union[str, List[str]]
EmbedderOutput = Union[torch.Tensor, np.ndarray]
RerankerInput = List[Tuple[str, str]]
RerankerOutput = np.ndarray

# Configuration types (devices accepts a single device string or a list)
DeviceList = Optional[Union[str, List[str]]]
InstructionFormat = str
ProcessPoolInfo = Dict[str, Any]
ProcessTarget = Callable[..., Any]  # worker function; receives several arguments

# Abstract base class references
AbstractEmbedder = AbsEmbedder
AbstractReranker = AbsReranker
```