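# Cross-Encoder

Cross-encoders process both texts of a pair together in a single forward pass, so each score reflects the full interaction between the two texts. This makes them well suited to tasks such as reranking, textual entailment (natural language inference), and semantic textual similarity, where two texts must be compared directly. Unlike bi-encoders, they do not produce reusable per-text embeddings: every new pair needs its own forward pass.

## CrossEncoder Class

### Constructor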

```python
CrossEncoder(
    model_name_or_path: str,
    num_labels: int | None = None,
    max_length: int | None = None,
    activation_fn: Callable | None = None,
    device: str | None = None,
    cache_folder: str | None = None,
    trust_remote_code: bool = False,
    revision: str | None = None,
    local_files_only: bool = False,
    token: bool | str | None = None,
    model_kwargs: dict | None = None,
    tokenizer_kwargs: dict | None = None,
    config_kwargs: dict | None = None,
    model_card_data: CrossEncoderModelCardData | None = None,
    backend: Literal["torch", "onnx", "openvino"] = "torch"
)
```

`{ .api }`

Initialize a CrossEncoder model for scoring sentence pairs.

**Parameters**:

- `model_name_or_path`: A model name on the Hugging Face Hub or a path to a local model directory
- `num_labels`: Number of labels of the classifier. With 1 label, the model outputs a single score per pair (by default passed through a sigmoid, giving a value in 0...1). With more than 1 label, it outputs one score per label, which can be turned into probabilities with softmax. If None, the value is taken from the model configuration where available
- `max_length`: Maximum length of the tokenized input (both texts of the pair combined). Longer inputs are truncated
- `activation_fn`: Callable (e.g. `torch.nn.Sigmoid()`) applied to the logits by default in `predict()` and `rank()`. If None, the activation stored with the model is used if there is one; otherwise Sigmoid when `num_labels == 1` and Identity when `num_labels > 1`
- `device`: Device to use for computation, e.g. "cuda", "cpu", "mps" or "npu". If None, a suitable device is detected automatically
- `cache_folder`: Path to the folder where downloaded models are cached
- `trust_remote_code`: Whether to allow custom models defined on the Hub in their own modeling files. This executes code from the Hub on your machine, so only enable it for repositories you trust
- `revision`: The specific model version to use: a branch name, tag name, or commit ID
- `local_files_only`: If True, only load local files and do not try to download the model
- `token`: Hugging Face authentication token used to download private models
- `model_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers model when it is loaded
- `tokenizer_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers tokenizer
- `config_kwargs`: Additional keyword arguments passed to the Hugging Face Transformers config
- `model_card_data`: A `CrossEncoderModelCardData` object with information about the model, used to generate the model card when saving
- `backend`: The inference backend: "torch", "onnx" or "openvino"
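
To make the relationship between `num_labels` and the default `activation_fn` concrete, here is a small pure-Python sketch of the rule described above. It is an illustration only: the helper names are hypothetical, it assumes the model configuration does not store its own activation (some published models may, for example to return raw logits), and it works on plain floats instead of tensors.

```python
import math
from typing import Callable


def sigmoid(x: float) -> float:
    """Map a raw logit to the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def identity(x: float) -> float:
    """Return the logit unchanged."""
    return x


def default_activation(num_labels: int) -> Callable[[float], float]:
    """Hypothetical helper mirroring the documented default:
    Sigmoid for a single label, Identity for several labels."""
    return sigmoid if num_labels == 1 else identity


# A num_labels=1 model produces one logit per pair (made-up values here)
logits = [2.0, -1.0, 0.0]
activation = default_activation(num_labels=1)
scores = [activation(z) for z in logits]
print([round(s, 4) for s in scores])  # [0.8808, 0.2689, 0.5]
```

Passing `activation_fn` explicitly, either to the constructor or to a single `predict()` call, replaces this default rather than stacking on top of it.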

### Prediction Methods

```python
def predict(
    sentences: list[tuple[str, str]] | list[list[str]] | tuple[str, str] | list[str],
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    activation_fn: Callable | None = None,
    apply_softmax: bool | None = False,
    convert_to_numpy: bool = True,
    convert_to_tensor: bool = False
) -> list[torch.Tensor] | np.ndarray | torch.Tensor
```

`{ .api }`

Predict scores for sentence pairs.

**Parameters**:

- `sentences`: A list of sentence pairs, e.g. `[(sent1, sent2), (sent3, sent4)]`, or a single pair `(sent1, sent2)`
- `batch_size`: Number of pairs processed per forward pass
- `show_progress_bar`: Whether to display a progress bar
- `activation_fn`: Activation function applied to the logits; overrides the model's default `activation_fn` for this call
- `apply_softmax`: If True and `model.num_labels > 1`, applies softmax to the logits so the scores for each pair sum to 1
- `convert_to_numpy`: If True, the output is a NumPy array instead of a list of tensors
- `convert_to_tensor`: If True, the output is a single stacked tensor; takes precedence over `convert_to_numpy`

**Returns**: Prediction scores for each sentence pair: one score per pair when `num_labels == 1`, otherwise one row of `num_labels` scores per pair. If a single pair is passed, only the result for that pair is returned.
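
For models with `num_labels > 1`, `apply_softmax=True` turns each row of logits into a probability distribution over the labels. The following pure-Python sketch reproduces that computation on plain lists; the helper name is hypothetical and the logits are made-up numbers, not output from a real model.

```python
import math


def softmax(row: list[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for two pairs and three labels
logits = [
    [2.0, 1.0, 0.1],
    [0.0, 0.0, 0.0],
]
probs = [softmax(row) for row in logits]
for row in probs:
    # Each row now sums to 1; the argmax is the predicted label index
    predicted = max(range(len(row)), key=row.__getitem__)
    print([round(p, 3) for p in row], predicted)
```

Softmax is monotonic, so it never changes which label has the highest score; it only rescales the scores into probabilities.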

```python
def rank(
    query: str,
    documents: list[str],
    top_k: int | None = None,
    return_documents: bool = False,
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    activation_fn: Callable | None = None,
    apply_softmax=False,
    convert_to_numpy: bool = True,
    convert_to_tensor: bool = False
) -> list[dict[str, int | float | str]]
```

`{ .api }`

Rank documents by their relevance to a query.

**Parameters**:

- `query`: A single query
- `documents`: A list of documents to rank
- `top_k`: Return only the `top_k` highest-scoring documents. If None, all documents are returned
- `return_documents`: If True, each result also contains the document text. If False, results contain only the document index and score
- `batch_size`: Number of query-document pairs processed per forward pass
- `show_progress_bar`: Whether to display a progress bar
- `activation_fn`: Activation function applied to the logits; overrides the model's default for this call
- `apply_softmax`: If True and the model outputs more than one score per pair, applies softmax to the logits
- `convert_to_numpy`: Convert the scores to a NumPy array
- `convert_to_tensor`: Convert the scores to a tensor

**Returns**: A list of dicts sorted by decreasing score, each with the keys `"corpus_id"` (index of the document in `documents`) and `"score"`, plus `"text"` when `return_documents=True`.
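
Conceptually, `rank()` scores every `(query, document)` pair, sorts by score, and truncates to `top_k`. The sketch below reproduces only that post-processing in plain Python. It assumes the scores come from a hypothetical word-overlap function standing in for the model; it is not the library's implementation.

```python
def toy_score(query: str, document: str) -> float:
    """Hypothetical stand-in for a cross-encoder: fraction of query words found in the document."""
    q = set(query.lower().split())
    d = set(document.lower().split())
    return len(q & d) / len(q) if q else 0.0


def rank_sketch(query, documents, top_k=None, return_documents=False):
    # Score each (query, document) pair, keeping the original index as corpus_id
    results = [
        {"corpus_id": i, "score": toy_score(query, doc)}
        for i, doc in enumerate(documents)
    ]
    if return_documents:
        for r in results:
            r["text"] = documents[r["corpus_id"]]
    # Highest score first; sorted() is stable, so ties keep document order
    results = sorted(results, key=lambda r: r["score"], reverse=True)
    return results[:top_k] if top_k is not None else results


docs = ["cats sit on mats", "dogs chase cats", "the stock market fell"]
for hit in rank_sketch("where do cats sit", docs, top_k=2, return_documents=True):
    print(hit["corpus_id"], round(hit["score"], 2), hit["text"])
```

Because `corpus_id` always refers to the position in the input list, results can be mapped back to the original documents even when `return_documents=False`.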

### Model Management

```python
def save(
    path: str,
    *,
    safe_serialization: bool = True,
    **kwargs
) -> None
```

`{ .api }`

Save the cross-encoder model to a directory.

```python
def save_pretrained(
    path: str,
    *,
    safe_serialization: bool = True,
    **kwargs
) -> None
```

`{ .api }`

Save the model in the Hugging Face format; identical to `save()`, named to match the Hugging Face Transformers convention.

```python
def push_to_hub(
    repo_id: str,
    *,
    token: str | None = None,
    private: bool | None = None,
    safe_serialization: bool = True,
    commit_message: str | None = None,
    exist_ok: bool = False,
    revision: str | None = None,
    create_pr: bool = False,
    tags: list[str] | None = None
) -> str
```

`{ .api }`

Upload the model to the Hugging Face Hub. Returns the URL of the commit.

### Properties

```python
@property
def device() -> torch.device
```

`{ .api }`

Current device of the model.

```python
@property
def tokenizer() -> PreTrainedTokenizer
```

`{ .api }`

Access to the model's tokenizer.

```python
@property
def config() -> PretrainedConfig
```

`{ .api }`

Model configuration object.

## CrossEncoderTrainer

### Constructor

```python
CrossEncoderTrainer(
    model: CrossEncoder | None = None,
    args: CrossEncoderTrainingArguments | None = None,
    train_dataset: Dataset | None = None,
    eval_dataset: Dataset | None = None,
    loss: nn.Module | None = None,
    evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
    tokenizer: PreTrainedTokenizer | None = None,
    data_collator: DataCollator | None = None,
    compute_metrics: Callable | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Callable | None = None
)
```

`{ .api }`

Trainer for cross-encoder models, built on top of the Hugging Face `transformers.Trainer`.

**Parameters**:

- `model`: CrossEncoder model to train
- `args`: Training arguments (`CrossEncoderTrainingArguments`)
- `train_dataset`: Training dataset
- `eval_dataset`: Evaluation dataset used to compute the evaluation loss
- `loss`: Loss function to train with. If None, a default is chosen from the model: binary cross-entropy when `num_labels == 1`, cross-entropy otherwise
- `evaluator`: Evaluator, or list of evaluators, run at each evaluation
- `tokenizer`: Tokenizer (usually taken from the model)
- `data_collator`: Data collator for batching
- `compute_metrics`: Function to compute evaluation metrics
- `callbacks`: Training callbacks
- `optimizers`: Custom optimizer and learning-rate scheduler
- `preprocess_logits_for_metrics`: Function that preprocesses logits before they are passed to `compute_metrics`
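
Training datasets are typically tables of text columns plus a label column, like the rows in the training example later on this page. The sketch below shows, in plain Python, how such rows can be split into input pairs and labels. It assumes the label lives in a column named `"label"` or `"score"` and that the remaining columns, in order, form the inputs; the helper is hypothetical and only illustrates that convention, it is not the library's data collator.

```python
from __future__ import annotations

LABEL_COLUMNS = ("label", "score")  # assumed names of the label column


def split_rows(rows: list[dict]) -> tuple[list[tuple[str, ...]], list[int | float]]:
    """Hypothetical helper: separate text inputs from the label column, keeping column order."""
    pairs, labels = [], []
    for row in rows:
        label_key = next((k for k in LABEL_COLUMNS if k in row), None)
        if label_key is None:
            raise ValueError(f"row has no label column: {sorted(row)}")
        labels.append(row[label_key])
        # Every other column is an input text, in the order the columns appear
        pairs.append(tuple(v for k, v in row.items() if k != label_key))
    return pairs, labels


rows = [
    {"sentence1": "The cat sits on the mat", "sentence2": "A feline rests on a rug", "label": 1},
    {"sentence1": "I love pizza", "sentence2": "Dogs are great pets", "label": 0},
]
pairs, labels = split_rows(rows)
print(pairs[0], labels)
```

Column order matters under this convention: the first text column is treated as the first element of the pair.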

### Training Methods

```python
def train(
    resume_from_checkpoint: str | bool | None = None,
    trial: dict[str, Any] | None = None,
    ignore_keys_for_eval: list[str] | None = None,
    **kwargs
) -> TrainOutput
```

`{ .api }`

Train the cross-encoder model. Returns a `TrainOutput` named tuple with `global_step`, `training_loss` and `metrics`.

```python
def evaluate(
    eval_dataset: Dataset | None = None,
    ignore_keys: list[str] | None = None,
    metric_key_prefix: str = "eval"
) -> dict[str, float]
```

`{ .api }`

Evaluate the model on `eval_dataset`, or on the evaluation dataset passed to the trainer if None. Returns a dict of metrics whose keys are prefixed with `metric_key_prefix`, e.g. `"eval_loss"`.

## CrossEncoderTrainingArguments

```python
class CrossEncoderTrainingArguments(TrainingArguments):
    def __init__(
        self,
        output_dir: str,
        eval_strategy: str | IntervalStrategy = "no",
        eval_steps: int | None = None,
        eval_delay: float = 0,
        logging_dir: str | None = None,
        logging_strategy: str | IntervalStrategy = "steps",
        logging_steps: int = 500,
        save_strategy: str | IntervalStrategy = "steps",
        save_steps: int = 500,
        save_total_limit: int | None = None,
        seed: int = 42,
        data_seed: int | None = None,
        jit_mode_eval: bool = False,
        use_ipex: bool = False,
        bf16: bool = False,
        fp16: bool = False,
        fp16_opt_level: str = "O1",
        half_precision_backend: str = "auto",
        bf16_full_eval: bool = False,
        fp16_full_eval: bool = False,
        tf32: bool | None = None,
        local_rank: int = -1,
        ddp_backend: str | None = None,
        tpu_num_cores: int | None = None,
        tpu_metrics_debug: bool = False,
        debug: str | list[DebugOption] = "",
        dataloader_drop_last: bool = False,
        dataloader_num_workers: int = 0,
        past_index: int = -1,
        run_name: str | None = None,
        disable_tqdm: bool | None = None,
        remove_unused_columns: bool = True,
        label_names: list[str] | None = None,
        load_best_model_at_end: bool = False,
        ignore_data_skip: bool = False,
        fsdp: str | list[str] = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: dict[str, Any] | None = None,
        fsdp_transformer_layer_cls_to_wrap: str | None = None,
        deepspeed: str | None = None,
        label_smoothing_factor: float = 0.0,
        optim: str | OptimizerNames = "adamw_torch",
        optim_args: str | None = None,
        adafactor: bool = False,
        group_by_length: bool = False,
        length_column_name: str | None = "length",
        report_to: str | list[str] | None = None,
        ddp_find_unused_parameters: bool | None = None,
        ddp_bucket_cap_mb: int | None = None,
        ddp_broadcast_buffers: bool | None = None,
        dataloader_pin_memory: bool = True,
        skip_memory_metrics: bool = True,
        use_legacy_prediction_loop: bool = False,
        push_to_hub: bool = False,
        resume_from_checkpoint: str | None = None,
        hub_model_id: str | None = None,
        hub_strategy: str | HubStrategy = "every_save",
        hub_token: str | None = None,
        hub_private_repo: bool = False,
        hub_always_push: bool = False,
        gradient_checkpointing: bool = False,
        include_inputs_for_metrics: bool = False,
        auto_find_batch_size: bool = False,
        full_determinism: bool = False,
        torchdynamo: str | None = None,
        ray_scope: str | None = "last",
        ddp_timeout: int = 1800,
        torch_compile: bool = False,
        torch_compile_backend: str | None = None,
        torch_compile_mode: str | None = None,
        dispatch_batches: bool | None = None,
        split_batches: bool | None = None,
        include_tokens_per_second: bool = False,
        **kwargs
    )
```

`{ .api }`

Training arguments for cross-encoder training, extending Hugging Face `TrainingArguments`. Only a subset of the inherited arguments is listed; other `TrainingArguments` fields such as `num_train_epochs`, `per_device_train_batch_size`, `learning_rate` and `warmup_ratio` are accepted as well. Recent `transformers` releases name the evaluation schedule `eval_strategy`; the older `evaluation_strategy` spelling is deprecated.
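
When `load_best_model_at_end=True`, the underlying Hugging Face `TrainingArguments` requires the save and evaluation strategies to match and, for the `"steps"` strategy, `save_steps` to be a round multiple of `eval_steps`. The hypothetical helper below sketches that check in plain Python so a configuration can be sanity-checked before a run; it is a simplified illustration, not the `transformers` validation code (for example, it skips the check when `eval_steps` is None).

```python
from __future__ import annotations


def check_best_model_settings(
    load_best_model_at_end: bool,
    eval_strategy: str,
    save_strategy: str,
    eval_steps: int | None,
    save_steps: int,
) -> list[str]:
    """Hypothetical helper: return a list of problems (empty if the settings are consistent)."""
    if not load_best_model_at_end:
        return []
    problems = []
    if eval_strategy != save_strategy:
        problems.append(
            f"save_strategy={save_strategy!r} must match eval_strategy={eval_strategy!r}"
        )
    elif eval_strategy == "steps" and eval_steps and save_steps % eval_steps != 0:
        problems.append(
            f"save_steps={save_steps} must be a round multiple of eval_steps={eval_steps}"
        )
    return problems


# Matching step-based settings: no problems
print(check_best_model_settings(True, "steps", "steps", eval_steps=100, save_steps=100))  # []
# save_steps is not a multiple of eval_steps: reported as a problem
print(check_best_model_settings(True, "steps", "steps", eval_steps=100, save_steps=250))
```

The constraint exists because the "best" checkpoint can only be restored if a checkpoint was actually saved at a step where evaluation ran.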

## CrossEncoderModelCardData

```python
class CrossEncoderModelCardData:
    def __init__(
        self,
        language: str | list[str] | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        model_id: str | None = None,
        eval_results: list[EvalResult] | None = None,
        train_datasets: list[dict[str, str]] | None = None,
        eval_datasets: list[dict[str, str]] | None = None
    )
```

`{ .api }`

Data class holding the metadata used to generate the model card (`README.md`) for a cross-encoder model.

**Parameters**:

- `language`: Language(s) supported by the model, e.g. "en" or ["en", "de"]
- `license`: Model license, e.g. "apache-2.0"
- `tags`: Tags for categorizing the model
- `model_name`: Human-readable model name
- `model_id`: Model ID on the Hugging Face Hub
- `eval_results`: Evaluation results to include
- `train_datasets`: Training datasets used, each given as a dict with a "name" and/or an "id" (Hugging Face dataset ID) key
- `eval_datasets`: Evaluation datasets used, in the same format as `train_datasets`

## Usage Examples

### Basic Cross-Encoder Usage

```python
import torch
from sentence_transformers import CrossEncoder

# Load a pre-trained reranking cross-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Score sentence pairs: one score per pair, higher means more relevant
pairs = [
    ('How many people live in Berlin?', 'Berlin has a population of 3,520,031'),
    ('How many people live in Berlin?', 'The weather in Berlin is nice'),
    ('What is the capital of France?', 'Paris is the capital of France'),
]

scores = cross_encoder.predict(pairs)
print("Relevance scores:", scores)

# This model has a single output label, so apply_softmax has no effect.
# A sigmoid maps each score to the 0...1 range instead.
probs = cross_encoder.predict(pairs, activation_fn=torch.nn.Sigmoid())
print("Sigmoid-normalized scores:", probs)
```

### Document Ranking

```python
query = "How to learn machine learning?"
documents = [
    "Machine learning is a subset of artificial intelligence",
    "Start with basic statistics and linear algebra",
    "Python is a popular programming language",
    "Practice with real datasets and projects",
    "Understanding algorithms is crucial for ML success",
]

# Rank documents by relevance; return_documents=True adds the text to each result
results = cross_encoder.rank(query, documents, top_k=3, return_documents=True)

for result in results:
    print(f"Score: {result['score']:.4f}")
    print(f"Document index: {result['corpus_id']}")
    print(f"Text: {result['text']}")
    print()
```

### Natural Language Inference (Multi-Class Classification)

```python
from sentence_transformers import CrossEncoder

# NLI model with three output labels
cross_encoder = CrossEncoder('cross-encoder/nli-deberta-v3-base')

# Natural Language Inference pairs: (premise, hypothesis)
nli_pairs = [
    ("A man is eating pizza", "A man is eating food"),  # expected: entailment
    ("A woman is reading a book", "A woman is cooking"),  # expected: contradiction
    ("It's raining outside", "The weather is bad"),  # expected: neutral
]

# One row of probabilities per pair, in this model's label order
scores = cross_encoder.predict(nli_pairs, apply_softmax=True)
label_names = ["contradiction", "entailment", "neutral"]

for pair, score in zip(nli_pairs, scores):
    prediction = label_names[score.argmax()]
    confidence = score.max()
    print(f"Premise: {pair[0]}")
    print(f"Hypothesis: {pair[1]}")
    print(f"Prediction: {prediction} (confidence: {confidence:.4f})")
    print()
```

### Training Cross-Encoder

```python
from datasets import Dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.losses import CrossEntropyLoss

# Training data: two text columns plus an integer "label" column
train_data = [
    {"sentence1": "The cat sits on the mat", "sentence2": "A feline rests on a rug", "label": 1},
    {"sentence1": "I love pizza", "sentence2": "Dogs are great pets", "label": 0},
    {"sentence1": "Machine learning is AI", "sentence2": "ML is a subset of artificial intelligence", "label": 1},
]
train_dataset = Dataset.from_list(train_data)

# Held-out evaluation data (required because evaluation is enabled below)
eval_dataset = Dataset.from_list([
    {"sentence1": "A dog runs in the park", "sentence2": "A canine is running outside", "label": 1},
    {"sentence1": "The sun is hot", "sentence2": "I bought new shoes", "label": 0},
])

# Initialize a cross-encoder with a fresh 2-label classification head
model = CrossEncoder('distilbert-base-uncased', num_labels=2)

# Cross-entropy over the two labels
loss = CrossEntropyLoss(model)

# Training arguments
args = CrossEncoderTrainingArguments(
    output_dir='./cross-encoder-output',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Create trainer
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
)

# Train model
trainer.train()

# Save trained model
model.save('./my-cross-encoder')
```

### Advanced Usage with Custom Activation

```python
import torch.nn as nn
from sentence_transformers import CrossEncoder

# Load model with a custom default activation function
cross_encoder = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    activation_fn=nn.Sigmoid(),
)

pairs = [
    ('How many people live in Berlin?', 'Berlin has a population of 3,520,031'),
    ('How many people live in Berlin?', 'The weather in Berlin is nice'),
]

# Override the default activation for this call only
scores = cross_encoder.predict(
    pairs,
    activation_fn=nn.Tanh(),
)

# Batch prediction with a progress bar
large_pairs = [("query " + str(i), "document " + str(i)) for i in range(1000)]
scores = cross_encoder.predict(
    large_pairs,
    batch_size=64,
    show_progress_bar=True,
)
```

### Model Card Generation

```python
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData

# Create model card data
model_card_data = CrossEncoderModelCardData(
    language=['en'],
    license='apache-2.0',
    tags=['sentence-transformers', 'cross-encoder', 'reranking'],
    model_name='My Custom Cross-Encoder',
    train_datasets=[{'name': 'MS MARCO', 'id': 'microsoft/ms_marco'}],
    eval_datasets=[{'name': 'TREC DL 2019'}],
)

# Model card data is passed to the constructor, not to save()
cross_encoder = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    model_card_data=model_card_data,
)

# Saving writes the model files together with a generated model card (README.md)
cross_encoder.save('./my-model')
```

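## Best Practices

1. **Task Selection**: Use cross-encoders for tasks that require a direct comparison between the two texts of a pair, such as reranking, natural language inference, or pairwise similarity scoring
2. **Performance**: Cross-encoders are usually more accurate than bi-encoders but much slower, because every pair needs its own forward pass and no embeddings can be precomputed. For large collections, retrieve candidates with a bi-encoder first and rerank only the top results with a cross-encoder
3. **Batch Size**: Increase `batch_size` in `predict()` and `rank()` for better GPU utilization, as far as GPU memory allows
4. **Activation Functions**: Match the activation to the task: a sigmoid for single-label relevance scores in 0...1, `apply_softmax=True` for multi-class outputs such as NLI, or an identity to keep raw logits
5. **Model Selection**: When possible, start from a model pre-trained on a similar task (e.g. MS MARCO models for reranking, NLI models for entailment)
6. **Evaluation**: Always evaluate on held-out test sets to get reliable performance metrics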