# Core Model Classes

Scikit-learn compatible model classes that provide the main interfaces for CatBoost gradient boosting. These classes handle classification, regression, and ranking tasks with comprehensive parameter configuration and training options.

## Capabilities

### CatBoost Base Class

The foundational model class providing core gradient boosting functionality with training, prediction, feature importance, and model persistence methods.

```python { .api }
class CatBoost:
    def __init__(self, params=None):
        """
        Initialize CatBoost model with parameters.

        Parameters:
        - params (dict): Model parameters
        """

    def fit(self, X, y=None, cat_features=None, text_features=None,
            embedding_features=None, pairs=None, graph=None, sample_weight=None,
            group_id=None, group_weight=None, subgroup_id=None, pairs_weight=None,
            baseline=None, use_best_model=None, eval_set=None, verbose=None,
            logging_level=None, plot=False, plot_file=None, early_stopping_rounds=None,
            save_snapshot=None, snapshot_file=None, snapshot_interval=600,
            init_model=None):
        """
        Train the CatBoost model.

        Parameters:
        - X: Training data (Pool, list, numpy.ndarray, pandas.DataFrame, pandas.Series, FeaturesData, or file path)
        - y: Target values (array-like)
        - cat_features: Categorical feature column indices or names
        - text_features: Text feature column indices or names
        - embedding_features: Embedding feature column indices or names
        - pairs: Pairs for ranking (array-like)
        - graph: Graph for collaborative filtering
        - sample_weight: Sample weights
        - group_id: Group identifiers for ranking
        - group_weight: Group weights
        - subgroup_id: Subgroup identifiers
        - pairs_weight: Pairs weights
        - baseline: Baseline values
        - use_best_model: Use best model from evaluation
        - eval_set: Evaluation datasets [(X, y), ...]
        - verbose: Verbosity level
        - logging_level: Logging level
        - plot: Enable plotting
        - plot_file: Plot output file
        - early_stopping_rounds: Early stopping rounds
        - save_snapshot: Save training snapshots
        - snapshot_file: Snapshot file name
        - snapshot_interval: Snapshot interval in seconds
        - init_model: Initial model for continued training

        Returns:
        Self
        """

    def predict(self, data, prediction_type='RawFormulaVal', ntree_start=0,
                ntree_end=0, thread_count=-1, verbose=None, task_type='CPU'):
        """
        Make predictions on data.

        Parameters:
        - data: Input data (Pool or array-like)
        - prediction_type: Type of prediction ('RawFormulaVal', 'Class', 'Probability')
        - ntree_start: Start tree index
        - ntree_end: End tree index (0 means use all trees)
        - thread_count: Number of threads
        - verbose: Verbosity level
        - task_type: Task type ('CPU' or 'GPU')

        Returns:
        numpy.ndarray: Predictions
        """

    def get_feature_importance(self, data=None, type='FeatureImportance',
                               prettified=False, thread_count=-1, shap_mode=None,
                               interaction_indices=None, shap_calc_type='Regular',
                               model_output_type='RawFormulaVal', **kwargs):
        """
        Calculate feature importance.

        Parameters:
        - data: Data for importance calculation (Pool or array-like)
        - type: Importance type (EFstrType enum value)
        - prettified: Return prettified DataFrame
        - thread_count: Number of threads
        - shap_mode: SHAP calculation mode
        - interaction_indices: Feature indices for interaction
        - shap_calc_type: SHAP calculation type
        - model_output_type: Model output type

        Returns:
        numpy.ndarray or pandas.DataFrame: Feature importance values
        """

    def get_object_importance(self, pool, train_pool, top_size=-1,
                              type='Average', update_method='SinglePoint',
                              importance_values_sign='All', thread_count=-1):
        """
        Calculate object importance (leaf influence).

        Parameters:
        - pool: Pool for importance calculation
        - train_pool: Training pool
        - top_size: Number of top important objects (-1 for all)
        - type: Importance type ('Average', 'PerObject')
        - update_method: Update method ('SinglePoint', 'TopKLeaves', 'AllPoints')
        - importance_values_sign: Values sign ('All', 'Positive', 'Negative')
        - thread_count: Number of threads

        Returns:
        numpy.ndarray: Object importance values
        """

    def save_model(self, fname, format='cbm', export_parameters=None, pool=None):
        """
        Save model to file.

        Parameters:
        - fname: File name or file-like object
        - format: Model format ('cbm', 'json', 'onnx', 'pmml', 'python', 'cpp')
        - export_parameters: Export parameters for specific formats
        - pool: Pool for ONNX export
        """

    def load_model(self, fname=None, format='cbm', stream=None, blob=None):
        """
        Load model from file.

        Parameters:
        - fname: File name
        - format: Model format
        - stream: Input stream
        - blob: Model blob data
        """

    def copy(self):
        """Create a copy of the model."""

    def get_params(self, deep=True):
        """Get model parameters."""

    def set_params(self, **params):
        """Set model parameters."""
```

### CatBoostClassifier

Scikit-learn compatible classifier with binary and multi-class classification support, including probability predictions and class-specific methods.

```python { .api }
class CatBoostClassifier(CatBoost):
    def __init__(self, iterations=1000, learning_rate=None, depth=6, l2_leaf_reg=3.0,
                 model_size_reg=0.5, rsm=1.0, loss_function='Logloss',
                 border_count=254, feature_border_type='GreedyLogSum',
                 per_float_feature_quantization=None, input_borders=None,
                 output_borders=None, fold_permutation_block=1,
                 od_pval=0.001, od_wait=20, od_type='IncToDec', nan_mode='Min',
                 counter_calc_method='SkipTest', leaf_estimation_iterations=None,
                 leaf_estimation_method='Newton', thread_count=-1,
                 random_seed=None, use_best_model=None, best_model_min_trees=1,
                 verbose=None, silent=None, logging_level=None, metric_period=1,
                 ctr_leaf_count_limit=None, store_all_simple_ctr=None,
                 max_ctr_complexity=4, has_time=False, allow_const_label=None,
                 target_border=None, classes_count=None, class_weights=None,
                 auto_class_weights=None, class_names=None, one_hot_max_size=None,
                 random_strength=1.0, name='experiment', ignored_features=None,
                 train_dir=None, custom_loss=None, custom_metric=None,
                 eval_metric=None, bagging_temperature=1.0, save_snapshot=None,
                 snapshot_file=None, snapshot_interval=600, fold_len_multiplier=2.0,
                 used_ram_limit='1gb', gpu_ram_part=0.95, pinned_memory_size='104857600',
                 allow_writing_files=True, final_ctr_computation_mode='Default',
                 approx_on_full_history=False, boosting_type=None, simple_ctr=None,
                 combinations_ctr=None, per_feature_ctr=None, ctr_description=None,
                 ctr_target_border_count=None, task_type=None, device_config=None,
                 devices=None, bootstrap_type=None, subsample=None,
                 sampling_unit='Object', dev_score_calc_obj_block_size=None,
                 max_depth=None, grow_policy='SymmetricTree', min_data_in_leaf=1,
                 max_leaves=31, num_boost_round=None, feature_weights=None,
                 penalties_coefficient=1.0, first_feature_use_penalties=None,
                 model_shrink_rate=None, model_shrink_mode=None, langevin=False,
                 diffusion_temperature=10000.0, posterior_sampling=False,
                 boost_from_average=None, text_features=None,
                 tokenizers=None, dictionaries=None, feature_calcers=None,
                 text_processing=None, embedding_features=None, **kwargs):
        """
        Initialize CatBoost classifier.

        Key Parameters:
        - iterations (int): Number of boosting iterations (default: 1000)
        - learning_rate (float): Learning rate (default: auto-calculated)
        - depth (int): Tree depth (default: 6)
        - l2_leaf_reg (float): L2 regularization coefficient (default: 3.0)
        - loss_function (str): Loss function ('Logloss', 'CrossEntropy', 'MultiClass', 'MultiClassOneVsAll')
        - class_weights (list/dict): Class weights for imbalanced datasets
        - auto_class_weights (str): Automatic class weight calculation ('Balanced', 'SqrtBalanced')
        - eval_metric (str): Evaluation metric ('Logloss', 'AUC', 'Accuracy', 'Precision', 'Recall', 'F1')
        - early_stopping_rounds (int): Early stopping rounds
        - task_type (str): Task type ('CPU' or 'GPU')
        - verbose (bool/int): Verbosity level
        """

    def fit(self, X, y, cat_features=None, text_features=None,
            embedding_features=None, graph=None, sample_weight=None,
            baseline=None, use_best_model=None, eval_set=None, verbose=None,
            logging_level=None, plot=False, plot_file=None,
            early_stopping_rounds=None, save_snapshot=None, snapshot_file=None,
            snapshot_interval=600, init_model=None):
        """
        Train the classifier.

        Parameters: Same as CatBoost.fit()

        Returns:
        Self
        """

    def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0,
                thread_count=-1, verbose=None, task_type='CPU'):
        """
        Predict class labels.

        Parameters:
        - data: Input data
        - prediction_type: 'Class' for class labels, 'RawFormulaVal' for raw values

        Returns:
        numpy.ndarray: Predicted class labels
        """

    def predict_proba(self, X, ntree_start=0, ntree_end=0, thread_count=-1,
                      verbose=None, task_type='CPU'):
        """
        Predict class probabilities.

        Parameters:
        - X: Input data
        - ntree_start: Start tree index
        - ntree_end: End tree index
        - thread_count: Number of threads
        - verbose: Verbosity level
        - task_type: Task type

        Returns:
        numpy.ndarray: Class probabilities (n_samples, n_classes)
        """

    def predict_log_proba(self, data, ntree_start=0, ntree_end=0, thread_count=-1,
                          verbose=None, task_type='CPU'):
        """
        Predict logarithm of class probabilities.

        Returns:
        numpy.ndarray: Log probabilities
        """

    def staged_predict(self, data, prediction_type='Class', ntree_start=0,
                       ntree_end=0, eval_period=1, thread_count=-1, verbose=None):
        """
        Predict for each stage of boosting.

        Returns:
        generator: Predictions for each boosting iteration
        """

    def staged_predict_proba(self, data, ntree_start=0, ntree_end=0, eval_period=1,
                             thread_count=-1, verbose=None):
        """
        Predict probabilities for each stage of boosting.

        Returns:
        generator: Probabilities for each boosting iteration
        """

    @property
    def classes_(self):
        """Get class labels."""

    @property
    def feature_importances_(self):
        """Get feature importances (scikit-learn compatibility)."""
```

### CatBoostRegressor

Scikit-learn compatible regressor supporting various loss functions for different regression tasks including standard regression, quantile regression, and survival analysis.

```python { .api }
class CatBoostRegressor(CatBoost):
    def __init__(self, iterations=1000, learning_rate=None, depth=6, l2_leaf_reg=3.0,
                 model_size_reg=0.5, rsm=1.0, loss_function='RMSE',
                 border_count=128, feature_border_type='GreedyLogSum',
                 # ... (same parameters as CatBoostClassifier except loss-specific ones)
                 **kwargs):
        """
        Initialize CatBoost regressor.

        Key Parameters:
        - loss_function (str): Loss function ('RMSE', 'MAE', 'Quantile:alpha=0.5',
                               'LogLinQuantile:alpha=0.5', 'Poisson', 'MAPE',
                               'Lq:q=2', 'SurvivalAft:dist=Normal;scale=1.0')
        - eval_metric (str): Evaluation metric ('RMSE', 'MAE', 'R2', 'MSLE', 'MedianAbsoluteError')
        """

    def fit(self, X, y, **kwargs):
        """Train the regressor. Same interface as CatBoost.fit()."""

    def predict(self, data, **kwargs):
        """
        Predict target values.

        Returns:
        numpy.ndarray: Predicted values
        """

    def staged_predict(self, data, **kwargs):
        """
        Predict for each stage of boosting.

        Returns:
        generator: Predictions for each boosting iteration
        """

    @property
    def feature_importances_(self):
        """Get feature importances (scikit-learn compatibility)."""
```

### CatBoostRanker

Scikit-learn compatible ranker for learning-to-rank tasks with support for various ranking loss functions and group-based evaluation.

```python { .api }
class CatBoostRanker(CatBoost):
    def __init__(self, iterations=1000, learning_rate=None, depth=6, l2_leaf_reg=3.0,
                 model_size_reg=0.5, rsm=1.0, loss_function='YetiRank',
                 # ... (same parameters as other CatBoost classes)
                 **kwargs):
        """
        Initialize CatBoost ranker.

        Key Parameters:
        - loss_function (str): Ranking loss function ('YetiRank', 'YetiRankPairwise',
                               'StochasticFilter', 'StochasticRank', 'QueryCrossEntropy',
                               'QueryRMSE', 'GroupQuantile:alpha=0.5', 'QuerySoftMax',
                               'PairLogit', 'PairLogitPairwise')
        - eval_metric (str): Ranking evaluation metric ('NDCG', 'DCG', 'MAP', 'MRR', 'ERR')
        """

    def fit(self, X, y, group_id=None, **kwargs):
        """
        Train the ranker.

        Parameters: Same as CatBoost.fit(); group_id matters most for ranking
        - group_id: Group identifiers for ranking (required for most ranking tasks)
        """

    def predict(self, data, **kwargs):
        """
        Predict ranking scores.

        Returns:
        numpy.ndarray: Ranking scores
        """

    def staged_predict(self, data, **kwargs):
        """
        Predict ranking scores for each stage of boosting.

        Returns:
        generator: Ranking scores for each boosting iteration
        """

    @property
    def feature_importances_(self):
        """Get feature importances (scikit-learn compatibility)."""
```

## Model Conversion Functions

```python { .api }
def to_classifier(model):
    """
    Convert CatBoost model to classifier.

    Parameters:
    - model: CatBoost model

    Returns:
    CatBoostClassifier: Converted classifier
    """

def to_regressor(model):
    """
    Convert CatBoost model to regressor.

    Parameters:
    - model: CatBoost model

    Returns:
    CatBoostRegressor: Converted regressor
    """

def to_ranker(model):
    """
    Convert CatBoost model to ranker.

    Parameters:
    - model: CatBoost model

    Returns:
    CatBoostRanker: Converted ranker
    """
```