# Training and Evaluation

CatBoost provides training and evaluation utilities covering cross-validation, hyperparameter search, early stopping, model ensembling, and Gaussian process sampling for uncertainty estimation. `train` and `cv` take their data as a `Pool` object, while `sample_gaussian_process` and `CatBoostSearchCV` also accept raw feature arrays.

## Capabilities

### Training Functions

Function-style training entry points that take a `Pool` and a parameter dictionary, as an alternative to constructing a model class and calling `fit`.

```python { .api }
def train(pool, params=None, dtrain=None, logging_level=None, verbose=None,
          iterations=None, num_boost_round=None, eval_set=None, plot=False,
          save_snapshot=None, snapshot_file=None, snapshot_interval=600,
          metric_period=1, verbose_eval=None, early_stopping_rounds=None,
          use_best_model=None, best_model_min_trees=1, log_cout=None,
          log_cerr=None):
    """
    Train a CatBoost model using the specified parameters.

    Parameters:
    - pool: Training data (Pool object)
    - params: Training parameters (dict)
    - dtrain: Deprecated alias for pool
    - logging_level: Logging level ('Silent', 'Verbose', 'Info', 'Debug')
    - verbose: Verbosity (bool, or int to log every N iterations)
    - iterations: Number of boosting iterations (int)
    - num_boost_round: Alias for iterations
    - eval_set: Evaluation datasets (list of Pool objects)
    - plot: Plot metrics during training; requires a Jupyter notebook (bool)
    - save_snapshot: Save training snapshots (bool)
    - snapshot_file: Snapshot file name (string)
    - snapshot_interval: Interval between snapshots, in seconds (int)
    - metric_period: Calculate metrics every N iterations (int)
    - verbose_eval: Verbose evaluation period (int)
    - early_stopping_rounds: Stop training if the eval metric has not improved for this many iterations (int)
    - use_best_model: Keep only the trees up to the iteration with the best eval metric (bool)
    - best_model_min_trees: Minimum number of trees the best model must contain (int)
    - log_cout: Output stream for logging
    - log_cerr: Error stream for logging

    Returns:
    CatBoost: Trained CatBoost model
    """
```
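The following plain-Python sketch replays a list of per-iteration validation losses to show how `early_stopping_rounds`, `use_best_model`, and `best_model_min_trees` interact. It assumes a lower-is-better metric, 0-based iterations, and patience counted from the last improvement; it is a simplified illustration, not CatBoost's internal logic, and `replay_early_stopping` is a hypothetical helper.

```python
def replay_early_stopping(eval_losses, early_stopping_rounds, best_model_min_trees=1):
    """Hypothetical helper: simplified model of early stopping and best-model selection.

    Assumes a lower-is-better metric. Iteration i (0-based) corresponds to a
    model with i + 1 trees. Returns (last_iteration_trained, best_iteration).
    """
    best_iteration = None
    best_loss = float("inf")
    for i, loss in enumerate(eval_losses):
        # Only models with at least best_model_min_trees trees may become "best".
        if i + 1 >= best_model_min_trees and loss < best_loss:
            best_loss = loss
            best_iteration = i
        # Stop once the metric has not improved for early_stopping_rounds iterations.
        if best_iteration is not None and i - best_iteration >= early_stopping_rounds:
            return i, best_iteration
    # Ran out of iterations without triggering early stopping.
    return len(eval_losses) - 1, best_iteration


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.9]
last, best = replay_early_stopping(losses, early_stopping_rounds=3)
# With use_best_model=True, only the first best + 1 trees would be kept.
print(f"stopped after {last + 1} trees, best model has {best + 1} trees")
# prints: stopped after 6 trees, best model has 3 trees
```

In this sketch, raising `best_model_min_trees` to 5 makes iterations 0 to 3 ineligible, so training runs to the end and the best model has 5 trees.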

### Cross-Validation

K-fold cross-validation with optional stratification, custom folds, and per-iteration train and test metrics.

```python { .api }
def cv(pool, params=None, dtrain=None, iterations=None, num_boost_round=None,
       fold_count=3, inverted=False, shuffle=True, partition_random_seed=0,
       stratified=None, train_dir=None, verbose=None, logging_level=None,
       metric_period=1, verbose_eval=None, plot=False, save_snapshot=None,
       snapshot_file=None, snapshot_interval=600, folds=None,
       early_stopping_rounds=None, as_pandas=True, return_models=False,
       log_cout=None, log_cerr=None, type='Classical'):
    """
    Perform cross-validation with CatBoost.

    Parameters:
    - pool: Training data (Pool object)
    - params: Model parameters (dict)
    - dtrain: Deprecated alias for pool
    - iterations: Number of boosting iterations (int)
    - num_boost_round: Alias for iterations
    - fold_count: Number of cross-validation folds (int, default: 3)
    - inverted: Train on one fold and evaluate on the remaining folds (bool)
    - shuffle: Shuffle data before splitting into folds (bool)
    - partition_random_seed: Random seed for data partitioning (int)
    - stratified: Use stratified folds (bool; if None, chosen automatically and typically enabled for classification)
    - train_dir: Directory for training artifacts (string)
    - verbose: Verbosity (bool, or int to log every N iterations)
    - logging_level: Logging level (string)
    - metric_period: Calculate metrics every N iterations (int)
    - verbose_eval: Verbose evaluation period (int)
    - plot: Plot metrics during training; requires a Jupyter notebook (bool)
    - save_snapshot: Save snapshots (bool)
    - snapshot_file: Snapshot file name (string)
    - snapshot_interval: Interval between snapshots, in seconds (int)
    - folds: Custom folds, as an iterable of (train_indices, test_indices) pairs or a splitter object with a split method
    - early_stopping_rounds: Stop training if the test metric has not improved for this many iterations (int)
    - as_pandas: Return results as a pandas DataFrame (bool)
    - return_models: Also return the models trained on each fold (bool)
    - log_cout: Output stream for logging
    - log_cerr: Error stream for logging
    - type: Cross-validation type ('Classical', 'Inverted', 'TimeSeries')

    Returns:
    pandas.DataFrame or dict: Per-iteration mean and standard deviation of each metric across folds (e.g. 'test-RMSE-mean', 'test-RMSE-std')
    """
```
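The NumPy sketch below illustrates what `fold_count`, `shuffle`, `partition_random_seed`, and `inverted` control: the row indices are shuffled with a seeded generator, split into `fold_count` folds, and one train/test split is produced per fold. It ignores stratification and does not reproduce CatBoost's actual shuffling, so fold contents will differ from `cv`; `make_folds` is a hypothetical helper, not part of the library.

```python
import numpy as np


def make_folds(n_rows, fold_count=3, shuffle=True, partition_random_seed=0, inverted=False):
    """Hypothetical helper: yield (train_idx, test_idx) pairs for K-fold CV.

    Simplified illustration only: no stratification, and not CatBoost's
    actual partitioning, so fold contents differ from cv().
    """
    indices = np.arange(n_rows)
    if shuffle:
        rng = np.random.default_rng(partition_random_seed)
        indices = rng.permutation(indices)
    folds = np.array_split(indices, fold_count)
    for k in range(fold_count):
        held_out = folds[k]
        rest = np.concatenate([folds[j] for j in range(fold_count) if j != k])
        if inverted:
            # Inverted CV: train on the single fold, evaluate on all the others.
            yield held_out, rest
        else:
            # Classical CV: train on the other folds, evaluate on the held-out fold.
            yield rest, held_out


for train_idx, test_idx in make_folds(10, fold_count=3, partition_random_seed=42):
    print(len(train_idx), len(test_idx))
# prints "6 4", then "7 3" twice
```

Because every row is held out exactly once, each row contributes to exactly one test-fold score, which is what makes the per-iteration mean across folds a fair estimate.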

### Gaussian Process Sampling

Uncertainty estimation for regression tasks by drawing samples from a Gaussian process posterior.

```python { .api }
def sample_gaussian_process(X, y, eval_set=None, cat_features=None,
                            text_features=None, embedding_features=None,
                            sample_weight=None, group_id=None, group_weight=None,
                            subgroup_id=None, pairs=None, pairs_weight=None,
                            baseline=None, n_samples=100, random_seed=None,
                            logging_level='Verbose', verbose=None, plot=False,
                            model_params=None, gp_params=None):
    """
    Sample from a Gaussian process for uncertainty estimation.

    Parameters:
    - X: Input features (array-like or Pool)
    - y: Target values (array-like)
    - eval_set: Evaluation datasets (list of tuples)
    - cat_features: Categorical feature indices (list)
    - text_features: Text feature indices (list)
    - embedding_features: Embedding feature indices (list)
    - sample_weight: Sample weights (array-like)
    - group_id: Group identifiers (array-like)
    - group_weight: Group weights (array-like)
    - subgroup_id: Subgroup identifiers (array-like)
    - pairs: Pairs for ranking (array-like)
    - pairs_weight: Pair weights (array-like)
    - baseline: Baseline values (array-like)
    - n_samples: Number of GP samples (int, default: 100)
    - random_seed: Random seed (int)
    - logging_level: Logging level (string)
    - verbose: Verbosity level (bool or int)
    - plot: Enable plotting (bool)
    - model_params: CatBoost model parameters (dict)
    - gp_params: Gaussian process parameters (dict)

    Returns:
    tuple: (predictions_mean, predictions_std, gp_samples)
    - predictions_mean: Mean predictions (numpy.ndarray)
    - predictions_std: Standard deviation of predictions (numpy.ndarray)
    - gp_samples: Individual GP samples (numpy.ndarray)
    """
```
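Assuming `gp_samples` holds one row of predictions per sample, i.e. shape `(n_samples, n_objects)` (an assumption; the layout is not stated above), the per-object mean and standard deviation can be recovered with NumPy, and empirical percentiles give an interval that does not rely on a normal approximation. The sketch runs on synthetic samples; `summarize_samples` is a hypothetical helper.

```python
import numpy as np


def summarize_samples(samples, interval=0.95):
    """Hypothetical helper: per-object summary of sampled predictions.

    samples: array of shape (n_samples, n_objects), one row per posterior sample.
    Returns mean, std, and an empirical central interval for each object.
    """
    samples = np.asarray(samples, dtype=float)
    mean = samples.mean(axis=0)
    std = samples.std(axis=0)
    tail = (1.0 - interval) / 2.0
    # Empirical percentiles make no normality assumption about the samples.
    lower = np.quantile(samples, tail, axis=0)
    upper = np.quantile(samples, 1.0 - tail, axis=0)
    return mean, std, lower, upper


# Synthetic stand-in for gp_samples: 100 samples for 3 objects.
rng = np.random.default_rng(0)
fake_samples = rng.normal(loc=[1.0, 2.0, 3.0], scale=[0.1, 0.5, 1.0], size=(100, 3))
mean, std, lower, upper = summarize_samples(fake_samples)
print(np.round(mean, 2), np.round(std, 2))
```

The percentile interval and the `mean ± 1.96 * std` interval agree when the samples are roughly normal; the percentile version stays meaningful for skewed sample distributions.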

### Model Ensemble Functions

Functions for combining multiple CatBoost models into ensembles.

```python { .api }
def sum_models(models, weights=None, ctr_merge_policy='IntersectingCountersAverage'):
    """
    Create an ensemble by summing multiple CatBoost models.

    Parameters:
    - models: List of CatBoost models to combine (list)
    - weights: Weight applied to each model (list of float, optional)
    - ctr_merge_policy: How to merge the models' CTR tables (string, default: 'IntersectingCountersAverage')

    Returns:
    CatBoost: Combined model
    """

def _have_equal_features(models):
    """
    Check if models have equal feature sets.

    Parameters:
    - models: List of CatBoost models (list)

    Returns:
    bool: True if all models have equal features
    """
```
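Because a boosted model's raw prediction is a sum of tree outputs, scaling each model's leaf values by its weight and pooling the trees yields a model whose raw prediction is the weighted sum of the members' raw predictions. The pure-Python sketch below demonstrates that identity with toy one-split trees; it ignores CTR merging and model bias terms and is not how `sum_models` is implemented. `Stump`, `ToyEnsemble`, and `weighted_sum` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stump:
    """Toy one-split tree: returns left if x[feature] <= threshold else right."""
    feature: int
    threshold: float
    left: float
    right: float

    def predict(self, x: List[float]) -> float:
        return self.left if x[self.feature] <= self.threshold else self.right


@dataclass
class ToyEnsemble:
    """Toy additive model: raw prediction is the sum of its trees' outputs."""
    trees: List[Stump]

    def predict(self, x: List[float]) -> float:
        return sum(tree.predict(x) for tree in self.trees)


def weighted_sum(models: List[ToyEnsemble], weights: List[float]) -> ToyEnsemble:
    """Merge models by pooling their trees with leaf values scaled by each model's weight."""
    merged = []
    for model, w in zip(models, weights):
        for t in model.trees:
            merged.append(Stump(t.feature, t.threshold, w * t.left, w * t.right))
    return ToyEnsemble(merged)


m1 = ToyEnsemble([Stump(0, 0.5, 1.0, 2.0), Stump(1, 0.0, -1.0, 1.0)])
m2 = ToyEnsemble([Stump(0, 1.5, 0.0, 4.0)])
combined = weighted_sum([m1, m2], [0.75, 0.25])
x = [1.0, -2.0]
# The merged model's prediction equals the weighted sum of the members' predictions.
print(combined.predict(x), 0.75 * m1.predict(x) + 0.25 * m2.predict(x))
# prints: 0.75 0.75
```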

### Hyperparameter Optimization

Grid search and randomized search utilities for hyperparameter optimization.

```python { .api }
class CatBoostSearchCV:
    """
    Hyperparameter search using cross-validation.

    Similar to scikit-learn's GridSearchCV but optimized for CatBoost.
    """

    def __init__(self, estimator, param_grid, scoring=None, cv=3,
                 refit=True, verbose=0, n_jobs=1, return_train_score=False):
        """
        Initialize hyperparameter search.

        Parameters:
        - estimator: CatBoost estimator
        - param_grid: Parameter grid to search (dict)
        - scoring: Scoring metric (string or callable)
        - cv: Cross-validation strategy (int or cv splitter)
        - refit: Refit best estimator (bool)
        - verbose: Verbosity level (int)
        - n_jobs: Number of parallel jobs (int)
        - return_train_score: Return training scores (bool)
        """

    def fit(self, X, y=None, **fit_params):
        """Fit the search."""

    def predict(self, X):
        """Predict using the best estimator."""

    @property
    def best_estimator_(self):
        """Best estimator found."""

    @property
    def best_params_(self):
        """Best parameters found."""

    @property
    def best_score_(self):
        """Best cross-validation score."""
```
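Conceptually, a grid search evaluates every combination in `param_grid` and keeps the one with the best cross-validated score. The plain-Python sketch below shows that loop, with a hypothetical `score_fn` standing in for "train with these parameters and return the mean CV score" and assuming higher scores are better; it is not how `CatBoostSearchCV` is implemented, and `grid_search` and `toy_score` are hypothetical names.

```python
from itertools import product


def grid_search(param_grid, score_fn):
    """Hypothetical helper: exhaustive search over a dict of parameter lists.

    score_fn(params) -> float, higher is better (e.g. a mean CV score).
    Returns (best_params, best_score, all_results).
    """
    names = sorted(param_grid)
    results = []
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        results.append((params, score_fn(params)))
    best_params, best_score = max(results, key=lambda item: item[1])
    return best_params, best_score, results


# Toy scoring function standing in for cross-validation; it peaks at depth=6, learning_rate=0.1.
def toy_score(params):
    return -abs(params["depth"] - 6) - 10 * abs(params["learning_rate"] - 0.1)


grid = {"depth": [4, 6, 8], "learning_rate": [0.03, 0.1]}
best_params, best_score, results = grid_search(grid, toy_score)
print(best_params, len(results))
# prints: {'depth': 6, 'learning_rate': 0.1} 6
```

A randomized search differs only in scoring a random subset of the combinations produced by `product` instead of all of them, which matters when the grid grows multiplicatively with each parameter.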
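
## Training Examples

### Basic Training with Cross-Validation

```python
from catboost import Pool, train, cv
import pandas as pd

# Prepare data: 'train.csv' is expected to contain a numeric 'target' column
# and a categorical 'category_column'
df = pd.read_csv('train.csv')
X = df.drop('target', axis=1)
y = df['target']

# Create Pool; categorical features can be given by column name
train_pool = Pool(
    data=X,
    label=y,
    cat_features=['category_column']
)

# Define parameters
params = {
    'iterations': 1000,
    'learning_rate': 0.1,
    'depth': 6,
    'loss_function': 'RMSE',
    'eval_metric': 'RMSE',
    'verbose': 100
}

# Perform 5-fold cross-validation, stopping early if the test metric
# has not improved for 50 iterations
cv_results = cv(
    pool=train_pool,
    params=params,
    fold_count=5,
    stratified=False,  # stratification does not apply to a continuous target
    shuffle=True,
    partition_random_seed=42,
    early_stopping_rounds=50
)

# Locate the iteration with the lowest mean test RMSE ('iterations' is 0-based)
best_row = cv_results['test-RMSE-mean'].idxmin()
best_iteration = int(cv_results.loc[best_row, 'iterations'])
print(f"Best CV score: {cv_results.loc[best_row, 'test-RMSE-mean']:.4f} "
      f"at iteration {best_iteration}")

# Train the final model on all data with the number of trees chosen by CV;
# there is no eval set here, so early stopping is not used
final_params = dict(params, iterations=best_iteration + 1)
model = train(
    pool=train_pool,
    params=final_params
)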
```

### Advanced Training with Validation Set

```python
from catboost import Pool, train
from sklearn.model_selection import train_test_split

# Split data (X and y as prepared in the previous example)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Create pools
train_pool = Pool(X_train, y_train, cat_features=['category_column'])
val_pool = Pool(X_val, y_val, cat_features=['category_column'])

# Train with validation: stop after 100 iterations without improvement
# on val_pool and keep only the trees up to the best iteration
model = train(
    pool=train_pool,
    eval_set=[val_pool],
    params={
        'iterations': 1000,
        'learning_rate': 0.1,
        'depth': 6,
        'loss_function': 'RMSE',
        'eval_metric': 'RMSE',
        'early_stopping_rounds': 100,
        'use_best_model': True,
        'verbose': 100
    },
    plot=True  # Interactive training plot; requires a Jupyter notebook
)

print(f"Best iteration: {model.get_best_iteration()}")
```
296
297
### Gaussian Process Uncertainty Estimation
298
299
```python
300
from catboost import sample_gaussian_process
301
import numpy as np
302
303
# Sample from Gaussian process
304
mean_pred, std_pred, gp_samples = sample_gaussian_process(
305
X=X_train,
306
y=y_train,
307
eval_set=[(X_val, y_val)],
308
cat_features=['category'],
309
n_samples=100,
310
model_params={
311
'iterations': 500,
312
'learning_rate': 0.1,
313
'depth': 4
314
},
315
random_seed=42
316
)
317
318
# Get prediction intervals
319
lower_bound = mean_pred - 1.96 * std_pred
320
upper_bound = mean_pred + 1.96 * std_pred
321
322
print(f"Mean prediction: {mean_pred.mean():.4f}")
323
print(f"Prediction uncertainty: {std_pred.mean():.4f}")
324
```
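One way to sanity-check such intervals is to measure their empirical coverage on held-out targets aligned with the predictions: for a well-calibrated 95% interval, roughly 95% of targets should fall between the bounds. The helper `interval_coverage` below is hypothetical and is demonstrated on synthetic data so that it needs only NumPy.

```python
import numpy as np


def interval_coverage(y_true, lower, upper):
    """Hypothetical helper: fraction of targets inside [lower, upper]."""
    y_true, lower, upper = (np.asarray(a, dtype=float) for a in (y_true, lower, upper))
    inside = (y_true >= lower) & (y_true <= upper)
    return inside.mean()


# Synthetic check: targets drawn around a known center with std 1.0, so the
# interval center +/- 1.96 * spread should cover about 95% of them.
rng = np.random.default_rng(0)
center = np.zeros(10_000)
spread = np.ones(10_000)
y_synth = rng.normal(center, spread)
coverage = interval_coverage(y_synth, center - 1.96 * spread, center + 1.96 * spread)
print(f"Empirical coverage: {coverage:.3f}")
```

Coverage well below the nominal level on real data suggests the predicted standard deviations are too small.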

### Model Ensemble

```python
from catboost import CatBoostRegressor, sum_models

# Train multiple models with different depths
# (X_train, y_train, X_val from the earlier examples)
models = []
for depth in [4, 6, 8]:
    model = CatBoostRegressor(
        iterations=500,
        depth=depth,
        learning_rate=0.1,
        cat_features=['category_column'],
        verbose=False
    )
    model.fit(X_train, y_train)
    models.append(model)

# Create ensemble; the weights sum to 1, so the ensemble acts as a weighted average
ensemble = sum_models(
    models=models,
    weights=[0.4, 0.4, 0.2]
)

# Make predictions with the ensemble on held-out data
ensemble_pred = ensemble.predict(X_val)
```