0
# Core Model Classes
1
2
Scikit-learn compatible model classes that provide the main interfaces for CatBoost gradient boosting. These classes handle classification, regression, and ranking tasks with comprehensive parameter configuration and training options.
3
4
## Capabilities
5
6
### CatBoost Base Class
7
8
The foundational model class providing core gradient boosting functionality with training, prediction, feature importance, and model persistence methods.
9
10
```python { .api }
11
class CatBoost:
12
def __init__(self, params=None):
13
"""
14
Initialize CatBoost model with parameters.
15
16
Parameters:
17
- params (dict): Model parameters
18
"""
19
20
def fit(self, X, y=None, cat_features=None, text_features=None,
21
embedding_features=None, pairs=None, graph=None, sample_weight=None,
22
group_id=None, group_weight=None, subgroup_id=None, pairs_weight=None,
23
baseline=None, use_best_model=None, eval_set=None, verbose=None,
24
logging_level=None, plot=False, plot_file=None, early_stopping_rounds=None,
25
save_snapshot=None, snapshot_file=None, snapshot_interval=600,
26
init_model=None):
27
"""
28
Train the CatBoost model.
29
30
Parameters:
31
- X: Training data (Pool, list, numpy.ndarray, pandas.DataFrame, pandas.Series, FeaturesData, or file path)
32
- y: Target values (array-like)
33
- cat_features: Categorical feature column indices or names
34
- text_features: Text feature column indices or names
35
- embedding_features: Embedding feature column indices or names
36
- pairs: Pairs for ranking (array-like)
37
- graph: Graph edges list (2-dimensional array-like: list, numpy.ndarray, or pandas.DataFrame)
38
- sample_weight: Sample weights
39
- group_id: Group identifiers for ranking
40
- group_weight: Group weights
41
- subgroup_id: Subgroup identifiers
42
- pairs_weight: Pairs weights
43
- baseline: Baseline values
44
- use_best_model: Use best model from evaluation
45
- eval_set: Evaluation datasets [(X, y), ...]
46
- verbose: Verbosity level
47
- logging_level: Logging level
48
- plot: Enable plotting
49
- plot_file: Plot output file
50
- early_stopping_rounds: Early stopping rounds
51
- save_snapshot: Save training snapshots
52
- snapshot_file: Snapshot file name
53
- snapshot_interval: Snapshot interval in seconds
54
- init_model: Initial model for continued training
55
56
Returns:
57
Self
58
"""
59
60
def predict(self, data, prediction_type='RawFormulaVal', ntree_start=0,
61
ntree_end=0, thread_count=-1, verbose=None, task_type='CPU'):
62
"""
63
Make predictions on data.
64
65
Parameters:
66
- data: Input data (Pool or array-like)
67
- prediction_type: Type of prediction ('RawFormulaVal', 'Class', 'Probability', 'LogProbability', 'Exponent', 'RMSEWithUncertainty')
68
- ntree_start: Start tree index
69
- ntree_end: End tree index (0 means use all trees)
70
- thread_count: Number of threads
71
- verbose: Verbosity level
72
- task_type: Task type ('CPU' or 'GPU')
73
74
Returns:
75
numpy.ndarray: Predictions
76
"""
77
78
def get_feature_importance(self, data=None, type='FeatureImportance',
79
prettified=False, thread_count=-1, shap_mode=None,
80
interaction_indices=None, shap_calc_type='Regular',
81
model_output_type='RawFormulaVal', **kwargs):
82
"""
83
Calculate feature importance.
84
85
Parameters:
86
- data: Data for importance calculation (Pool or array-like)
87
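- type: Importance type (EFstrType value or its string name, e.g. 'FeatureImportance', 'PredictionValuesChange', 'LossFunctionChange', 'ShapValues', 'Interaction')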
88
- prettified: Return prettified DataFrame
89
- thread_count: Number of threads
90
- shap_mode: SHAP calculation mode
91
- interaction_indices: Feature indices for interaction
92
- shap_calc_type: SHAP calculation type
93
- model_output_type: Model output type
94
95
Returns:
96
numpy.ndarray or pandas.DataFrame: Feature importance values
97
"""
98
99
def get_object_importance(self, pool, train_pool, top_size=-1,
100
type='Average', update_method='SinglePoint',
101
importance_values_sign='All', thread_count=-1):
102
"""
103
Calculate object importance (leaf influence).
104
105
Parameters:
106
- pool: Pool for importance calculation
107
- train_pool: Training pool
108
- top_size: Number of top important objects (-1 for all)
109
- type: Importance type ('Average', 'PerObject')
110
- update_method: Update method ('SinglePoint', 'TopKLeaves', 'AllPoints')
111
- importance_values_sign: Values sign ('All', 'Positive', 'Negative')
112
- thread_count: Number of threads
113
114
Returns:
115
numpy.ndarray: Object importance values
116
"""
117
118
def save_model(self, fname, format='cbm', export_parameters=None, pool=None):
119
"""
120
Save model to file.
121
122
Parameters:
123
- fname: Output file path (str or pathlib.Path)
124
- format: Model format ('cbm', 'coreml', 'json', 'onnx', 'pmml', 'python', 'cpp')
125
- export_parameters: Export parameters for specific formats
126
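- pool: Dataset previously used for training (Pool or array-like); required when the model contains categorical features and the format is 'cpp', 'python', or 'json'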
127
"""
128
129
def load_model(self, fname=None, format='cbm', stream=None, blob=None):
130
"""
131
Load model from file.
132
133
Parameters:
134
- fname: File name
135
- format: Model format
136
- stream: Input stream
137
- blob: Model blob data
138
"""
139
140
def copy(self):
141
"""Create a copy of the model."""
142
143
def get_params(self, deep=True):
144
"""Get model parameters."""
145
146
def set_params(self, **params):
147
"""Set model parameters."""
148
```
149
150
### CatBoostClassifier
151
152
Scikit-learn compatible classifier with binary and multi-class classification support, including probability predictions and class-specific methods.
153
154
```python { .api }
155
class CatBoostClassifier(CatBoost):
156
def __init__(self, iterations=500, learning_rate=None, depth=6, l2_leaf_reg=3.0,
157
model_size_reg=0.5, rsm=1.0, loss_function='Logloss',
158
border_count=254, feature_border_type='GreedyLogSum',
159
per_float_feature_quantization=None, input_borders=None,
160
output_borders=None, fold_permutation_block=1,
161
od_pval=0.001, od_wait=20, od_type='IncToDec', nan_mode='Min',
162
counter_calc_method='SkipTest', leaf_estimation_iterations=None,
163
leaf_estimation_method='Newton', thread_count=-1,
164
random_seed=None, use_best_model=None, best_model_min_trees=1,
165
verbose=None, silent=None, logging_level=None, metric_period=1,
166
ctr_leaf_count_limit=None, store_all_simple_ctr=None,
167
max_ctr_complexity=4, has_time=False, allow_const_label=None,
168
target_border=None, classes_count=None, class_weights=None,
169
auto_class_weights=None, class_names=None, one_hot_max_size=None,
170
random_strength=1.0, name='experiment', ignored_features=None,
171
train_dir=None, custom_loss=None, custom_metric=None,
172
eval_metric=None, bagging_temperature=1.0, save_snapshot=None,
173
snapshot_file=None, snapshot_interval=600, fold_len_multiplier=2.0,
174
used_ram_limit='1gb', gpu_ram_part=0.95, pinned_memory_size='104857600',
175
allow_writing_files=True, final_ctr_computation_mode='Default',
176
approx_on_full_history=False, boosting_type=None, simple_ctr=None,
177
combinations_ctr=None, per_feature_ctr=None, ctr_description=None,
178
ctr_target_border_count=None, task_type=None, device_config=None,
179
devices=None, bootstrap_type=None, subsample=None,
180
sampling_unit='Object', dev_score_calc_obj_block_size=None,
181
max_depth=None, grow_policy='SymmetricTree', min_data_in_leaf=1,
182
max_leaves=31, num_boost_round=None, feature_weights=None,
183
penalties_coefficient=1.0, first_feature_use_penalties=None,
184
model_shrink_rate=None, model_shrink_mode=None, langevin=False,
185
diffusion_temperature=10000.0, posterior_sampling=False,
186
boost_from_average=None, text_features=None,
187
tokenizers=None, dictionaries=None, feature_calcers=None,
188
text_processing=None, embedding_features=None, **kwargs):
189
"""
190
Initialize CatBoost classifier.
191
192
Key Parameters:
193
- iterations (int): Number of boosting iterations (500 in this signature; CatBoost itself uses 1000 when the parameter is not set)
194
- learning_rate (float): Learning rate (default: auto-calculated)
195
- depth (int): Tree depth (default: 6)
196
- l2_leaf_reg (float): L2 regularization coefficient (default: 3.0)
197
- loss_function (str): Loss function ('Logloss', 'CrossEntropy', 'MultiClass', 'MultiClassOneVsAll')
198
- class_weights (list/dict): Class weights for imbalanced datasets
199
- auto_class_weights (str): Automatic class weight calculation ('Balanced', 'SqrtBalanced')
200
- eval_metric (str): Evaluation metric ('Logloss', 'AUC', 'Accuracy', 'Precision', 'Recall', 'F1')
201
- early_stopping_rounds (int): Early stopping rounds
202
- task_type (str): Task type ('CPU' or 'GPU')
203
- verbose (bool/int): Verbosity level
204
"""
205
206
def fit(self, X, y, cat_features=None, text_features=None,
207
embedding_features=None, graph=None, sample_weight=None,
208
baseline=None, use_best_model=None, eval_set=None, verbose=None,
209
logging_level=None, plot=False, plot_file=None,
210
early_stopping_rounds=None, save_snapshot=None, snapshot_file=None,
211
snapshot_interval=600, init_model=None):
212
"""
213
Train the classifier.
214
215
Parameters: Same as CatBoost.fit(), without the ranking-only arguments (pairs, group_id, group_weight, subgroup_id, pairs_weight)
216
217
Returns:
218
Self
219
"""
220
221
def predict(self, data, prediction_type='Class', ntree_start=0, ntree_end=0,
222
thread_count=-1, verbose=None, task_type='CPU'):
223
"""
224
Predict class labels.
225
226
Parameters:
227
- data: Input data
228
- prediction_type: 'Class' for class labels (default), 'Probability' for class probabilities, 'RawFormulaVal' for raw values
229
230
Returns:
231
numpy.ndarray: Predicted class labels (or probabilities / raw values, depending on prediction_type)
232
"""
233
234
def predict_proba(self, X, ntree_start=0, ntree_end=0, thread_count=-1,
235
verbose=None, task_type='CPU'):
236
"""
237
Predict class probabilities.
238
239
Parameters:
240
- X: Input data
241
- ntree_start: Start tree index
242
- ntree_end: End tree index
243
- thread_count: Number of threads
244
- verbose: Verbosity level
245
- task_type: Task type
246
247
Returns:
248
numpy.ndarray: Class probabilities (n_samples, n_classes)
249
"""
250
251
def predict_log_proba(self, data, ntree_start=0, ntree_end=0, thread_count=-1,
252
verbose=None, task_type='CPU'):
253
"""
254
Predict logarithm of class probabilities.
255
256
Returns:
257
numpy.ndarray: Log probabilities
258
"""
259
260
def staged_predict(self, data, prediction_type='Class', ntree_start=0,
261
ntree_end=0, eval_period=1, thread_count=-1, verbose=None):
262
"""
263
Predict for each stage of boosting.
264
265
Returns:
266
generator: Predictions for each boosting iteration
267
"""
268
269
def staged_predict_proba(self, data, ntree_start=0, ntree_end=0, eval_period=1,
270
thread_count=-1, verbose=None):
271
"""
272
Predict probabilities for each stage of boosting.
273
274
Returns:
275
generator: Probabilities for each boosting iteration
276
"""
277
278
@property
279
def classes_(self):
280
"""Get class labels."""
281
282
@property
283
def feature_importances_(self):
284
"""Get feature importances (scikit-learn compatibility)."""
285
```
286
287
### CatBoostRegressor
288
289
Scikit-learn compatible regressor supporting various loss functions for different regression tasks including standard regression, quantile regression, and survival analysis.
290
291
```python { .api }
292
class CatBoostRegressor(CatBoost):
293
def __init__(self, iterations=500, learning_rate=None, depth=6, l2_leaf_reg=3.0,
294
model_size_reg=0.5, rsm=1.0, loss_function='RMSE',
295
border_count=128, feature_border_type='GreedyLogSum',
296
# ... (same parameters as CatBoostClassifier except loss-specific ones)
297
**kwargs):
298
"""
299
Initialize CatBoost regressor.
300
301
Key Parameters:
302
- loss_function (str): Loss function ('RMSE', 'MAE', 'Quantile:alpha=0.5',
303
'LogLinQuantile:alpha=0.5', 'Poisson', 'MAPE',
304
'Lq:q=2', 'SurvivalAft:dist=Normal;scale=1.0')
305
- eval_metric (str): Evaluation metric ('RMSE', 'MAE', 'R2', 'MSLE', 'MedianAbsoluteError')
306
"""
307
308
def fit(self, X, y, **kwargs):
309
"""Train the regressor. Same interface as CatBoost.fit()."""
310
311
def predict(self, data, **kwargs):
312
"""
313
Predict target values.
314
315
Returns:
316
numpy.ndarray: Predicted values
317
"""
318
319
def staged_predict(self, data, **kwargs):
320
"""
321
Predict for each stage of boosting.
322
323
Returns:
324
generator: Predictions for each boosting iteration
325
"""
326
327
@property
328
def feature_importances_(self):
329
"""Get feature importances (scikit-learn compatibility)."""
330
```
331
332
### CatBoostRanker
333
334
Scikit-learn compatible ranker for learning-to-rank tasks with support for various ranking loss functions and group-based evaluation.
335
336
```python { .api }
337
class CatBoostRanker(CatBoost):
338
def __init__(self, iterations=500, learning_rate=None, depth=6, l2_leaf_reg=3.0,
339
model_size_reg=0.5, rsm=1.0, loss_function='YetiRank',
340
# ... (same parameters as other CatBoost classes)
341
**kwargs):
342
"""
343
Initialize CatBoost ranker.
344
345
Key Parameters:
346
- loss_function (str): Ranking loss function ('YetiRank', 'YetiRankPairwise',
347
'StochasticFilter', 'StochasticRank', 'QueryCrossEntropy',
348
'QueryRMSE', 'GroupQuantile:alpha=0.5', 'QuerySoftMax',
349
'PairLogit', 'PairLogitPairwise')
350
- eval_metric (str): Ranking evaluation metric ('NDCG', 'DCG', 'MAP', 'MRR', 'ERR')
351
"""
352
353
def fit(self, X, y, group_id=None, **kwargs):
354
"""
355
Train the ranker.
356
357
Parameters: Same as CatBoost.fit(); the key ranking-specific argument is:
358
- group_id: Group identifiers for ranking (required for most ranking tasks)
359
"""
360
361
def predict(self, data, **kwargs):
362
"""
363
Predict ranking scores.
364
365
Returns:
366
numpy.ndarray: Ranking scores
367
"""
368
369
def staged_predict(self, data, **kwargs):
370
"""
371
Predict ranking scores for each stage of boosting.
372
373
Returns:
374
generator: Ranking scores for each boosting iteration
375
"""
376
377
@property
378
def feature_importances_(self):
379
"""Get feature importances (scikit-learn compatibility)."""
380
```
381
382
## Model Conversion Functions
383
384
```python { .api }
385
def to_classifier(model):
386
"""
387
Convert CatBoost model to classifier.
388
389
Parameters:
390
- model: CatBoost model
391
392
Returns:
393
CatBoostClassifier: Converted classifier
394
"""
395
396
def to_regressor(model):
397
"""
398
Convert CatBoost model to regressor.
399
400
Parameters:
401
- model: CatBoost model
402
403
Returns:
404
CatBoostRegressor: Converted regressor
405
"""
406
407
def to_ranker(model):
408
"""
409
Convert CatBoost model to ranker.
410
411
Parameters:
412
- model: CatBoost model
413
414
Returns:
415
CatBoostRanker: Converted ranker
416
"""
417
```