# Cross-Encoder

Cross-encoders process both texts of a pair jointly, in a single forward pass through the model. This makes them well suited to tasks that require a direct comparison between two texts, such as reranking, textual entailment, and semantic textual similarity.

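"Jointly" means the two texts share one input. For a BERT-style model, the pair is typically packed into a single token sequence separated by special tokens, with segment ids recording which text each token came from, so every attention layer sees both texts at once. The sketch below builds such an input with naive whitespace tokenization. The `[CLS]`/`[SEP]` markers follow the BERT convention and are an assumption here: other architectures use different special tokens, and a real model always relies on its own tokenizer.

```python
def build_pair_input(text_a: str, text_b: str) -> tuple[list[str], list[int]]:
    """Pack two texts into one BERT-style sequence (toy whitespace tokens, illustration only)."""
    tokens_a = text_a.lower().split()
    tokens_b = text_b.lower().split()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment ids: 0 for [CLS], text A and its [SEP]; 1 for text B and the final [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


tokens, segments = build_pair_input(
    "How many people live in Berlin?", "Berlin has a population of 3,520,031"
)
for token, segment in zip(tokens, segments):
    print(f"{segment} {token}")
```

A bi-encoder, by contrast, encodes each text separately and compares the two vectors afterwards, so its embeddings can be precomputed and reused; a cross-encoder has to run the model again for every new pair.
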
## CrossEncoder Class

### Constructor

```python
CrossEncoder(
    model_name_or_path: str,
    num_labels: int | None = None,
    max_length: int | None = None,
    activation_fn: Callable | None = None,
    device: str | None = None,
    cache_folder: str | None = None,
    trust_remote_code: bool = False,
    revision: str | None = None,
    local_files_only: bool = False,
    token: bool | str | None = None,
    model_kwargs: dict | None = None,
    tokenizer_kwargs: dict | None = None,
    config_kwargs: dict | None = None,
    model_card_data: CrossEncoderModelCardData | None = None,
    backend: Literal["torch", "onnx", "openvino"] = "torch"
)
```
`{ .api }`

Initialize a CrossEncoder model for scoring sentence pairs.

**Parameters**:
- `model_name_or_path`: A model name from the Hugging Face Hub or a path to a local model
- `num_labels`: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a single continuous score (0...1 with a sigmoid activation). If > 1, it outputs one score per class, which can be soft-maxed into class probabilities
- `max_length`: Maximum length, in tokens, of the combined input pair. Longer inputs are truncated
- `activation_fn`: Callable (such as `nn.Sigmoid()`) used as the default activation on top of the logits in `model.predict()`. If None, the activation saved in the model's configuration is used; if there is none, it defaults to `nn.Sigmoid()` when `num_labels == 1` and `nn.Identity()` otherwise
- `device`: Device (e.g. "cuda", "cpu", "mps", "npu") used for computation
- `cache_folder`: Path to the folder in which downloaded models are cached
- `trust_remote_code`: Whether to allow custom models defined on the Hub in their own modeling files. Only enable this for repositories you trust, since it executes code from the Hub on your machine
- `revision`: The specific model version to use. Can be a branch name, tag name, or commit id
- `local_files_only`: Whether to only look at local files (i.e., do not try to download the model)
- `token`: Hugging Face authentication token for downloading private models
- `model_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers model when it is loaded
- `tokenizer_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers tokenizer
- `config_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers config
- `model_card_data`: A `CrossEncoderModelCardData` object with information about the model, used to generate the model card when the model is saved
- `backend`: The backend used for inference: "torch", "onnx", or "openvino"

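The `num_labels` and `activation_fn` parameters together decide how raw logits become scores. The pure-Python sketch below is an illustration of the math, not the library's code: it applies a sigmoid to a single logit (the `num_labels=1` case) and a softmax to several logits (the `num_labels > 1` case, which is what `apply_softmax=True` in `predict()` does). The logit values are made up.

```python
import math


def sigmoid(logit: float) -> float:
    """Map one raw logit to (0, 1), as nn.Sigmoid does."""
    return 1.0 / (1.0 + math.exp(-logit))


def softmax(logits: list[float]) -> list[float]:
    """Map several raw logits to probabilities that sum to 1."""
    largest = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# num_labels=1: one logit per pair becomes one score in (0, 1)
print(round(sigmoid(2.0), 4))  # 0.8808

# num_labels=3: three logits per pair become three class probabilities
probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```
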
### Prediction Methods

```python
def predict(
    sentences: list[tuple[str, str]] | list[list[str]] | tuple[str, str] | list[str],
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    activation_fn: Callable | None = None,
    apply_softmax: bool | None = False,
    convert_to_numpy: bool = True,
    convert_to_tensor: bool = False
) -> list[torch.Tensor] | np.ndarray | torch.Tensor
```
`{ .api }`

Predict scores for sentence pairs.

**Parameters**:
- `sentences`: A list of sentence pairs `[(Sent1, Sent2), (Sent3, Sent4)]` or a single sentence pair `(Sent1, Sent2)`
- `batch_size`: Number of pairs scored per forward pass
- `show_progress_bar`: Whether to display a progress bar
- `activation_fn`: Activation function applied to the logits output of the CrossEncoder. If None, the model's default activation is used
- `apply_softmax`: If True and `model.num_labels > 1`, applies softmax to the logits so that the class scores of each pair sum to 1
- `convert_to_numpy`: Whether to return the scores as a NumPy array
- `convert_to_tensor`: Whether to return the scores as one large tensor

**Returns**: Prediction scores for each sentence pair; a single result if a single pair was passed

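The input handling and `batch_size` behaviour described above can be sketched as follows: a single pair is wrapped into a one-element list, the pairs are scored in chunks of at most `batch_size`, and the single result is unwrapped again at the end. This is a simplified illustration, not the library's implementation. It only handles the tuple form of a single pair, and `toy_score_batch` is a hypothetical stand-in for the model's forward pass that just counts shared words.

```python
from typing import Callable, Iterator, Union

Pair = tuple[str, str]


def iter_batches(pairs: list[Pair], batch_size: int) -> Iterator[list[Pair]]:
    """Yield consecutive slices of at most batch_size pairs."""
    for start in range(0, len(pairs), batch_size):
        yield pairs[start:start + batch_size]


def toy_score_batch(batch: list[Pair]) -> list[float]:
    """Hypothetical scorer: number of lowercase words the two texts share."""
    return [float(len(set(a.lower().split()) & set(b.lower().split()))) for a, b in batch]


def predict_sketch(
    sentences: Union[Pair, list[Pair]],
    score_batch: Callable[[list[Pair]], list[float]],
    batch_size: int = 32,
) -> Union[float, list[float]]:
    single = isinstance(sentences, tuple)  # a single (Sent1, Sent2) pair was passed
    pairs = [sentences] if single else sentences
    scores: list[float] = []
    for batch in iter_batches(pairs, batch_size):
        scores.extend(score_batch(batch))  # the real model runs one forward pass per batch
    return scores[0] if single else scores


pairs = [("the cat sat", "a cat sat down"), ("hello world", "goodbye moon"), ("red fox", "red fox")]
print(predict_sketch(pairs, toy_score_batch, batch_size=2))
print(predict_sketch(("red fox", "red fox"), toy_score_batch))
```
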
```python
def rank(
    query: str,
    documents: list[str],
    top_k: int | None = None,
    return_documents: bool = False,
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    activation_fn: Callable | None = None,
    apply_softmax=False,
    convert_to_numpy: bool = True,
    convert_to_tensor: bool = False
) -> list[dict[str, int | float | str]]
```
`{ .api }`

Rank documents based on their relevance to a query.

**Parameters**:
- `query`: A single query
- `documents`: A list of documents
- `top_k`: Return only the top-k documents. If None, all documents are returned
- `return_documents`: If True, each result also contains the document text under the `text` key. If False, only the indices and scores are returned
- `batch_size`: Number of query-document pairs scored per forward pass
- `show_progress_bar`: Whether to display a progress bar
- `activation_fn`: Activation function applied to the logits output of the CrossEncoder
- `apply_softmax`: If True and the model outputs more than one score per pair, applies softmax to the logits
- `convert_to_numpy`: Convert the output to a NumPy matrix
- `convert_to_tensor`: Convert the output to a tensor

**Returns**: A list of dictionaries sorted by decreasing score, each with the keys `corpus_id` (the index into `documents`) and `score`, plus `text` when `return_documents=True`

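Given one score per document, the ranked output described above amounts to a sort. The sketch below is a simplified illustration rather than the library's implementation: it builds `corpus_id`/`score`/`text` dictionaries and applies `top_k`. The scores are made-up values standing in for `predict()` output on the (query, document) pairs.

```python
from typing import Optional, Union

Result = dict[str, Union[int, float, str]]


def rank_from_scores(
    scores: list[float],
    documents: list[str],
    top_k: Optional[int] = None,
    return_documents: bool = False,
) -> list[Result]:
    """Turn one score per document into rank()-style result dicts, best first."""
    results: list[Result] = []
    for corpus_id, score in enumerate(scores):
        result: Result = {"corpus_id": corpus_id, "score": score}
        if return_documents:
            result["text"] = documents[corpus_id]
        results.append(result)
    # Highest score first; the sort is stable, so ties keep their original order
    results.sort(key=lambda r: r["score"], reverse=True)
    return results if top_k is None else results[:top_k]


documents = [
    "Berlin is the capital of Germany",
    "Paris is the capital of France",
    "France borders Spain",
]
scores = [0.12, 0.97, 0.40]  # made-up scores for the query "What is the capital of France?"
for result in rank_from_scores(scores, documents, top_k=2, return_documents=True):
    print(result)
```
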
### Model Management

```python
def save(
    path: str,
    *,
    safe_serialization: bool = True,
    **kwargs
) -> None
```
`{ .api }`

Save the cross-encoder (model weights, configuration, and tokenizer files) to a directory. With `safe_serialization=True`, the weights are stored in the safetensors format.

```python
def save_pretrained(
    path: str,
    *,
    safe_serialization: bool = True,
    **kwargs
) -> None
```
`{ .api }`

Same as `save()`, provided under the method name used by Hugging Face Transformers.

```python
def push_to_hub(
    repo_id: str,
    *,
    token: str | None = None,
    private: bool | None = None,
    safe_serialization: bool = True,
    commit_message: str | None = None,
    exist_ok: bool = False,
    revision: str | None = None,
    create_pr: bool = False,
    tags: list[str] | None = None
) -> str
```
`{ .api }`

Upload the model to the Hugging Face Hub repository `repo_id` and return the URL of the resulting commit.

### Properties

```python
@property
def device() -> torch.device
```
`{ .api }`

Current device of the model.

```python
@property
def tokenizer() -> PreTrainedTokenizer
```
`{ .api }`

Access to the model's tokenizer.

```python
@property
def config() -> PretrainedConfig
```
`{ .api }`

Model configuration object.

## CrossEncoderTrainer

### Constructor

```python
CrossEncoderTrainer(
    model: CrossEncoder | None = None,
    args: CrossEncoderTrainingArguments | None = None,
    train_dataset: Dataset | None = None,
    eval_dataset: Dataset | None = None,
    loss: nn.Module | None = None,
    evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
    tokenizer: PreTrainedTokenizer | None = None,
    data_collator: DataCollator | None = None,
    compute_metrics: Callable | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Callable | None = None
)
```
`{ .api }`

Trainer for cross-encoder models, built on the Hugging Face Transformers `Trainer`.

**Parameters**:
- `model`: CrossEncoder model to train
- `args`: Training arguments (a `CrossEncoderTrainingArguments` object)
- `train_dataset`: Training dataset. If the loss needs labels, the label column must be named `label` or `score`; all other columns are used as model inputs, in order
- `eval_dataset`: Evaluation dataset, used to compute the evaluation loss
- `loss`: Loss function to optimize. If None, `BinaryCrossEntropyLoss` is used when `model.num_labels == 1` and `CrossEntropyLoss` otherwise
- `evaluator`: Evaluator (or list of evaluators) run during evaluation to compute additional metrics
- `tokenizer`: Tokenizer (usually taken from the model)
- `data_collator`: Data collator used to form batches
- `compute_metrics`: Function to compute evaluation metrics
- `callbacks`: List of training callbacks
- `optimizers`: Tuple of a custom optimizer and learning-rate scheduler
- `preprocess_logits_for_metrics`: Function that preprocesses the logits before they are passed to `compute_metrics`

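The two default losses consume different targets: binary cross-entropy works on the single logit of a `num_labels=1` model and a target in [0, 1], while categorical cross-entropy works on one logit per class and an integer class index. The sketch below computes both for a single example with the standard formulas in pure Python. It illustrates the math only and is not the library's loss code, which operates on batched tensors.

```python
import math


def binary_cross_entropy_with_logit(logit: float, target: float) -> float:
    """BCE on one raw logit with a target in [0, 1], in its numerically stable form."""
    # Equivalent to -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits: list[float], label: int) -> float:
    """Categorical cross-entropy: negative log-softmax probability of the true class."""
    largest = max(logits)
    log_sum_exp = largest + math.log(sum(math.exp(v - largest) for v in logits))
    return log_sum_exp - logits[label]


# A relevant pair (target 1.0) with a confident positive logit gets a small loss
print(round(binary_cross_entropy_with_logit(3.0, 1.0), 4))
# The same target with a negative logit is penalized much more
print(round(binary_cross_entropy_with_logit(-3.0, 1.0), 4))
# Three-class example whose true class (index 1) has the largest logit
print(round(cross_entropy([0.5, 2.0, -1.0], label=1), 4))
```
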
### Training Methods

```python
def train(
    resume_from_checkpoint: str | bool | None = None,
    trial: dict[str, Any] | None = None,
    ignore_keys_for_eval: list[str] | None = None,
    **kwargs
) -> TrainOutput
```
`{ .api }`

Train the cross-encoder model. Pass `resume_from_checkpoint=True` to resume from the latest checkpoint in `args.output_dir`, or a checkpoint path to resume from that checkpoint.

```python
def evaluate(
    eval_dataset: Dataset | None = None,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "eval"
) -> dict[str, float]
```
`{ .api }`

Evaluate the model on `eval_dataset` (or, if omitted, on the evaluation dataset passed to the constructor) and return a dictionary of metrics whose names are prefixed with `metric_key_prefix`.

## CrossEncoderTrainingArguments

```python
class CrossEncoderTrainingArguments(TrainingArguments):
    def __init__(
        self,
        output_dir: str,
        eval_strategy: str | IntervalStrategy = "no",
        eval_steps: int | None = None,
        eval_delay: float = 0,
        logging_dir: str | None = None,
        logging_strategy: str | IntervalStrategy = "steps",
        logging_steps: int = 500,
        save_strategy: str | IntervalStrategy = "steps",
        save_steps: int = 500,
        save_total_limit: int | None = None,
        seed: int = 42,
        data_seed: int | None = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: bool | None = None,
        local_rank: int = -1,
        ddp_backend: str | None = None,
        tpu_num_cores: int | None = None,
        tpu_metrics_debug: bool = False,
        debug: str | list[DebugOption] = "",
        dataloader_drop_last: bool = False,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: str | None = None,
        disable_tqdm: bool | None = None,
        remove_unused_columns: bool = True,
        label_names: list[str] | None = None,
        load_best_model_at_end: bool = False,
        ignore_data_skip: bool = False,
        fsdp: str | list[str] = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: dict[str, Any] | None = None,
        fsdp_transformer_layer_cls_to_wrap: str | None = None,
        deepspeed: str | None = None,
        label_smoothing_factor: float = 0.0,
        optim: str | OptimizerNames = "adamw_torch",
        optim_args: str | None = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: str | None = "length",
        report_to: str | list[str] | None = None,
        ddp_find_unused_parameters: bool | None = None,
        ddp_bucket_cap_mb: int | None = None,
        ddp_broadcast_buffers: bool | None = None,
        dataloader_pin_memory: bool = True,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: str | None = None,
        hub_model_id: str | None = None,
        hub_strategy: str | HubStrategy = "every_save",
        hub_token: str | None = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: str | None = None,
        ray_scope: str | None = "last",
        ddp_timeout: int = 1800,
        torch_compile: bool = False,
        torch_compile_backend: str | None = None,
        torch_compile_mode: str | None = None,
        dispatch_batches: bool | None = None,
        split_batches: bool | None = None,
        include_tokens_per_second: bool = False,
        **kwargs
    )
```
`{ .api }`

Training arguments for cross-encoder training, extending the Hugging Face Transformers `TrainingArguments`. The signature above lists a selection of the inherited arguments; other `TrainingArguments` options, such as `num_train_epochs`, `per_device_train_batch_size`, or `learning_rate`, are accepted as well. Older Transformers releases call the `eval_strategy` argument `evaluation_strategy`.

## CrossEncoderModelCardData

```python
class CrossEncoderModelCardData:
    def __init__(
        self,
        language: str | list[str] | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        model_id: str | None = None,
        eval_results: list[EvalResult] | None = None,
        train_datasets: list[dict[str, str]] | None = None,
        eval_datasets: list[dict[str, str]] | None = None
    )
```
`{ .api }`

Data class for generating model cards for cross-encoder models.

**Parameters**:
- `language`: Language code(s) supported by the model, e.g. "en"
- `license`: License identifier of the model, e.g. "apache-2.0"
- `tags`: Tags for categorizing the model
- `model_name`: Human-readable model name
- `model_id`: Model identifier on the Hugging Face Hub
- `eval_results`: Evaluation results to include
- `train_datasets`: Training datasets used, each described by a dictionary with a `name` and/or a Hugging Face dataset `id`
- `eval_datasets`: Evaluation datasets used, in the same format as `train_datasets`

## Usage Examples

### Basic Cross-Encoder Usage

```python
import torch
from sentence_transformers import CrossEncoder

# Load a pre-trained cross-encoder (trained on MS MARCO for passage reranking)
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Score sentence pairs
pairs = [
    ('How many people live in Berlin?', 'Berlin has a population of 3,520,031'),
    ('How many people live in Berlin?', 'The weather in Berlin is nice'),
    ('What is the capital of France?', 'Paris is the capital of France')
]

scores = cross_encoder.predict(pairs)
print("Relevance scores:", scores)

# This model has a single output label, so apply_softmax would have no effect;
# a sigmoid activation maps its raw scores into the 0...1 range instead
probs = cross_encoder.predict(pairs, activation_fn=torch.nn.Sigmoid())
print("Relevance probabilities:", probs)
```

### Document Ranking

```python
# Reuses the `cross_encoder` loaded in the previous example
query = "How to learn machine learning?"
documents = [
    "Machine learning is a subset of artificial intelligence",
    "Start with basic statistics and linear algebra",
    "Python is a popular programming language",
    "Practice with real datasets and projects",
    "Understanding algorithms is crucial for ML success"
]

# Rank documents by relevance; return_documents=True adds the text to each result
results = cross_encoder.rank(query, documents, top_k=3, return_documents=True)

for result in results:
    print(f"Score: {result['score']:.4f}")
    print(f"Document index: {result['corpus_id']}")
    print(f"Text: {result['text']}")
    print()
```

### Multi-Class Classification (NLI)

```python
from sentence_transformers import CrossEncoder

# Natural Language Inference is a three-way classification task
cross_encoder = CrossEncoder('cross-encoder/nli-deberta-v3-base')

# (premise, hypothesis) pairs, annotated with the expected label
nli_pairs = [
    ("A man is eating pizza", "A man is eating food"),  # Entailment
    ("A woman is reading a book", "A woman is cooking"),  # Contradiction
    ("It's raining outside", "The weather is bad")  # Neutral
]

scores = cross_encoder.predict(nli_pairs, apply_softmax=True)
# Each row holds probabilities for [contradiction, entailment, neutral]
label_names = ["contradiction", "entailment", "neutral"]

for pair, score in zip(nli_pairs, scores):
    prediction = label_names[score.argmax()]
    confidence = score.max()
    print(f"Premise: {pair[0]}")
    print(f"Hypothesis: {pair[1]}")
    print(f"Prediction: {prediction} (confidence: {confidence:.4f})")
    print()
```

### Training a Cross-Encoder

```python
from datasets import Dataset
from sentence_transformers import CrossEncoder, CrossEncoderTrainer, CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder.losses import CrossEntropyLoss

# Toy training data: "label" holds the target, the other columns are the model inputs
train_data = [
    {"sentence1": "The cat sits on the mat", "sentence2": "A feline rests on a rug", "label": 1},
    {"sentence1": "I love pizza", "sentence2": "Dogs are great pets", "label": 0},
    {"sentence1": "Machine learning is AI", "sentence2": "ML is a subset of artificial intelligence", "label": 1}
]
train_dataset = Dataset.from_list(train_data)

# Step-based evaluation needs an evaluation dataset (kept separate from the training data)
eval_dataset = Dataset.from_list([
    {"sentence1": "A dog runs in the park", "sentence2": "A canine is running outdoors", "label": 1},
    {"sentence1": "The stock market fell", "sentence2": "She planted tomatoes", "label": 0}
])

# Initialize a cross-encoder with a two-class classification head
model = CrossEncoder('distilbert-base-uncased', num_labels=2)

# Training arguments
args = CrossEncoderTrainingArguments(
    output_dir='./cross-encoder-output',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    logging_steps=10,
    eval_strategy="steps",  # `evaluation_strategy` in older Transformers releases
    eval_steps=100,
    save_strategy="steps",  # must match eval_strategy for load_best_model_at_end
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Create the trainer; CrossEntropyLoss expects integer class labels
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=CrossEntropyLoss(model),
)

# Train the model
trainer.train()

# Save the trained model
model.save('./my-cross-encoder')
```

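With a toy dataset this small, the step intervals in the training example (logging every 10 steps, evaluation and saving every 100) are never reached: three examples at a batch size of 16 give one optimization step per epoch, so three steps in total. The sketch below does that arithmetic. It assumes a single device, no gradient accumulation, the default `dataloader_drop_last=False`, and that a step-based action fires whenever the global step is a multiple of its interval; this simplifies the Transformers callback logic and ignores `eval_delay` as well as any save your Transformers version may perform at the very end of training.

```python
import math


def total_training_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps on one device without gradient accumulation (last partial batch kept)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


def steps_hit(total_steps: int, interval: int) -> list[int]:
    """Global steps at which an interval-based action (log/eval/save) would fire."""
    return [step for step in range(1, total_steps + 1) if step % interval == 0]


# Settings from the training example above
total = total_training_steps(num_examples=3, batch_size=16, num_epochs=3)
print("total steps:", total)
print("logging at:", steps_hit(total, 10))
print("eval/save at:", steps_hit(total, 100))

# A hypothetical dataset of 10,000 pairs makes the same intervals meaningful
total = total_training_steps(num_examples=10_000, batch_size=16, num_epochs=3)
print("total steps:", total, "eval/save count:", len(steps_hit(total, 100)))
```
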
### Advanced Usage with Custom Activation

```python
import torch.nn as nn
from sentence_transformers import CrossEncoder

# Load the model with a custom default activation (applied on top of the logits in predict())
cross_encoder = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    activation_fn=nn.Sigmoid()
)

pairs = [
    ('How many people live in Berlin?', 'Berlin has a population of 3,520,031'),
    ('How many people live in Berlin?', 'The weather in Berlin is nice')
]

# Override the default activation for a single call
scores = cross_encoder.predict(
    pairs,
    activation_fn=nn.Tanh()
)

# Batch prediction with a progress bar
large_pairs = [("query " + str(i), "document " + str(i)) for i in range(1000)]
scores = cross_encoder.predict(
    large_pairs,
    batch_size=64,
    show_progress_bar=True
)
```

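Swapping one activation for another, as in the example above, changes the scale of the scores but not their order, as long as both functions are strictly increasing: Sigmoid maps logits into (0, 1) and Tanh into (-1, 1), yet the best-scoring pair stays the best. The pure-Python sketch below checks this on a few made-up logits. It means the choice of activation does not matter when scores are only used for ranking, whereas any fixed score threshold has to be chosen per activation.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def ranking(scores: list[float]) -> list[int]:
    """Indices of the scores, ordered from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


logits = [4.2, -3.1, 0.7, 1.9]  # made-up raw model outputs for four pairs
sigmoid_scores = [sigmoid(x) for x in logits]
tanh_scores = [math.tanh(x) for x in logits]

print("raw order:    ", ranking(logits))
print("sigmoid order:", ranking(sigmoid_scores))
print("tanh order:   ", ranking(tanh_scores))
```
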
### Model Card Generation

```python
from sentence_transformers import CrossEncoder, CrossEncoderModelCardData

# Describe the model for its generated model card
model_card_data = CrossEncoderModelCardData(
    language=['en'],
    license='apache-2.0',
    tags=['sentence-transformers', 'cross-encoder', 'reranking'],
    model_name='My Custom Cross-Encoder',
    train_datasets=[{'name': 'ms-marco'}],
    eval_datasets=[{'name': 'trec-dl-2019'}]
)

# Model card data is passed to the constructor
cross_encoder = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    model_card_data=model_card_data
)

# The model card is generated when the model is saved
cross_encoder.save('./my-model')
```

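## Best Practices

1. **Task Selection**: Use cross-encoders for tasks that require a direct comparison between the two texts of a pair, such as reranking, entailment, or similarity scoring
2. **Performance**: Cross-encoders are usually more accurate than bi-encoders but much slower, because every pair needs its own forward pass and nothing can be precomputed; for large collections, retrieve candidates with a bi-encoder first and rerank only those with a cross-encoder
3. **Batch Size**: Use larger batch sizes (`batch_size` in `predict()` and `rank()`) for better GPU utilization, as long as they fit in memory
4. **Activation Functions**: Choose the activation to match the task, e.g. a sigmoid for single-score relevance models and `apply_softmax=True` for multi-class models
5. **Model Selection**: Start from a model pre-trained on a similar task when possible
6. **Evaluation**: Always evaluate on a held-out test set for reliable performance metrics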
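The retrieve-then-rerank pattern recommended for large collections can be sketched in pure Python. Both scorers are hypothetical stand-ins: `cheap_score` plays the role of a fast bi-encoder similarity (shared word count) and `expensive_score` plays the role of the cross-encoder (Jaccard overlap). The sketch therefore shows only the control flow: score everything cheaply, keep the best few candidates, and spend the expensive scorer on those alone.

```python
def tokens(text: str) -> set[str]:
    """Lowercased word set with question marks stripped (toy tokenizer)."""
    return set(text.lower().replace("?", "").split())


def cheap_score(query: str, doc: str) -> float:
    """Hypothetical fast first-stage score (stand-in for bi-encoder similarity)."""
    return float(len(tokens(query) & tokens(doc)))


def expensive_score(query: str, doc: str) -> float:
    """Hypothetical slow second-stage score (stand-in for a cross-encoder)."""
    q, d = tokens(query), tokens(doc)
    union = q | d
    return len(q & d) / len(union) if union else 0.0


def retrieve_and_rerank(query: str, corpus: list[str], num_candidates: int, top_k: int):
    # Stage 1: score every document cheaply and keep the best num_candidates
    ranked = sorted(range(len(corpus)), key=lambda i: cheap_score(query, corpus[i]), reverse=True)
    candidates = ranked[:num_candidates]
    # Stage 2: run the expensive scorer only on the shortlisted candidates
    rescored = [(i, expensive_score(query, corpus[i])) for i in candidates]
    rescored.sort(key=lambda item: item[1], reverse=True)
    return rescored[:top_k], len(candidates)


corpus = [
    "Paris is the capital of France",
    "The capital of Germany is Berlin",
    "France is a country in Europe",
    "Pizza is popular in Italy",
    "What is the capital city of France and where is it",
]
results, expensive_calls = retrieve_and_rerank(
    "What is the capital of France?", corpus, num_candidates=3, top_k=2
)
print(f"expensive scorer ran {expensive_calls} times for {len(corpus)} documents")
for corpus_id, score in results:
    print(f"{score:.3f}  {corpus[corpus_id]}")
```

Note that the second stage reorders the shortlist: the cheap scorer puts the last document first, while the expensive scorer prefers the first one.
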