# Training and Validation

Comprehensive training and validation capabilities with support for custom datasets, hyperparameter tuning, distributed training, and model optimization techniques.

## Capabilities

### Model Training

Train YOLO models on custom datasets with extensive configuration options and built-in optimization techniques.

```python { .api }
def train(self, data=None, epochs=100, imgsz=640, batch=16, **kwargs):
    """
    Train the model on a dataset.

    Parameters:
    - data (str | Path): Path to dataset YAML file
    - epochs (int): Number of training epochs (default: 100)
    - imgsz (int): Image size for training (default: 640)
    - batch (int): Batch size (default: 16)
    - lr0 (float): Initial learning rate (default: 0.01)
    - lrf (float): Final learning rate as a fraction of lr0, i.e. final LR = lr0 * lrf (default: 0.01)
    - momentum (float): SGD momentum / Adam beta1 (default: 0.937)
    - weight_decay (float): Optimizer weight decay (default: 0.0005)
    - warmup_epochs (float): Warmup epochs (default: 3.0)
    - warmup_momentum (float): Initial momentum during warmup (default: 0.8)
    - warmup_bias_lr (float): Initial bias learning rate during warmup (default: 0.1)
    - box (float): Box loss gain (default: 7.5)
    - cls (float): Classification loss gain (default: 0.5)
    - dfl (float): Distribution focal loss gain (default: 1.5)
    - pose (float): Pose loss gain (default: 12.0)
    - kobj (float): Keypoint objectness loss gain (default: 2.0)
    - dropout (float): Dropout rate, classification training only (default: 0.0)
    - val (bool): Validate during training (default: True)
    - save (bool): Save training checkpoints (default: True)
    - save_period (int): Save a checkpoint every N epochs; disabled if < 1 (default: -1)
    - cache (bool | str): Cache images for faster training ('ram', 'disk', or False) (default: False)
    - device (str): Device to train on ('cpu', '0', '0,1', etc.) (default: None, auto-select)
    - workers (int): Number of dataloader worker threads (default: 8)
    - project (str): Project directory (default: None, i.e. 'runs/<task>')
    - name (str): Experiment name (default: None, i.e. 'train'; incremented to 'train2', ... unless exist_ok=True)
    - exist_ok (bool): Reuse an existing project/name directory instead of incrementing (default: False)
    - pretrained (bool | str): Use pretrained weights, or a path to weights to load (default: True)
    - optimizer (str): 'SGD', 'Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam', 'RMSProp', or 'auto' (default: 'auto')
    - verbose (bool): Verbose output (default: True)
    - seed (int): Random seed (default: 0)
    - deterministic (bool): Deterministic mode (default: True)
    - single_cls (bool): Treat all classes as a single class (default: False)
    - rect (bool): Rectangular training with minimal padding per batch (default: False)
    - cos_lr (bool): Use a cosine learning rate schedule instead of linear (default: False)
    - close_mosaic (int): Disable mosaic augmentation for the final N epochs (default: 10)
    - resume (bool | str): Resume training from the last checkpoint (default: False)
    - amp (bool): Automatic Mixed Precision training (default: True)
    - fraction (float): Fraction of the training set to use (default: 1.0)
    - profile (bool): Profile ONNX and TensorRT speeds (default: False)
    - freeze (int | List[int]): Freeze the first N layers, or the listed layer indices (default: None)

    Returns:
        Metrics from the final validation (a task-specific metrics object such as DetMetrics), or None
    """
```
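The `lr0`, `lrf`, `cos_lr`, `warmup_epochs`, and `warmup_bias_lr` arguments together define the learning-rate schedule. The sketch below is a framework-free approximation for illustration only: `lr_factor` and `warmup_lr` are hypothetical helpers, the decay formulas assume the linear and cosine schedules run from a factor of 1.0 at the first epoch down to `lrf` at the last, and warmup is reduced to a per-iteration linear ramp (the trainer converts `warmup_epochs` into a number of batch iterations).

```python
import math

# Illustrative sketch: hypothetical helpers, not part of the Ultralytics API.


def lr_factor(epoch: int, epochs: int, lrf: float, cos_lr: bool = False) -> float:
    """Multiplier applied to lr0 at a given epoch: 1.0 at epoch 0, lrf at epoch == epochs."""
    if cos_lr:
        # Cosine curve from 1.0 down to lrf
        return ((1 - math.cos(epoch * math.pi / epochs)) / 2) * (lrf - 1.0) + 1.0
    # Linear decay from 1.0 down to lrf
    return max(1 - epoch / epochs, 0.0) * (1.0 - lrf) + lrf


def warmup_lr(iteration: int, warmup_iters: int, target_lr: float, start_lr: float = 0.0) -> float:
    """Linear ramp from start_lr to target_lr over warmup_iters iterations.

    Assumption: bias parameters start from warmup_bias_lr, all others from 0.0.
    """
    if warmup_iters <= 0 or iteration >= warmup_iters:
        return target_lr
    return start_lr + (target_lr - start_lr) * iteration / warmup_iters


# Compare the two schedules with the default lr0=0.01, lrf=0.01 over 100 epochs
lr0, lrf, epochs = 0.01, 0.01, 100
for epoch in (0, 25, 50, 99):
    linear = lr0 * lr_factor(epoch, epochs, lrf)
    cosine = lr0 * lr_factor(epoch, epochs, lrf, cos_lr=True)
    print(f"epoch {epoch}: linear={linear:.6f} cosine={cosine:.6f}")
```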
**Usage Examples:**

```python
from ultralytics import YOLO

# Load a model
model = YOLO("yolo11n.pt")

# Basic training
results = model.train(data="coco8.yaml", epochs=100, imgsz=640)

# Advanced training configuration
results = model.train(
    data="custom_dataset.yaml",
    epochs=300,
    imgsz=1280,
    batch=8,
    lr0=0.001,
    optimizer='AdamW',
    mixup=0.1,
    copy_paste=0.1,
    device='0,1',  # Multi-GPU (DDP) training on GPUs 0 and 1
    workers=16,
    project='my_project',
    name='custom_experiment'
)

# Resume an interrupted run: load its last checkpoint, then resume
model = YOLO("my_project/custom_experiment/weights/last.pt")
results = model.train(resume=True)

# Train with a custom callback (train() has no `callbacks` argument; register with add_callback)
def on_epoch_end(trainer):
    print(f"Epoch {trainer.epoch} completed")

model = YOLO("yolo11n.pt")
model.add_callback('on_train_epoch_end', on_epoch_end)
results = model.train(data="dataset.yaml", epochs=100)
```
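Several training arguments are per-epoch or per-layer switches rather than continuous values, and `close_mosaic` in particular counts epochs back from the end of training. The helpers below are hypothetical and exist only to make those semantics concrete; they assume 0-indexed epochs and layer parameter names of the form `model.<index>.`, and they are not part of the library's API.

```python
from typing import List, Optional, Sequence, Union

# Illustrative sketch: hypothetical helpers, not part of the Ultralytics API.


def mosaic_enabled(epoch: int, epochs: int, close_mosaic: int) -> bool:
    """Mosaic stays on until the final `close_mosaic` epochs (epochs are 0-indexed)."""
    return epoch < epochs - close_mosaic


def saves_periodic_checkpoint(epoch: int, save_period: int) -> bool:
    """save_period < 1 disables the extra per-epoch checkpoint (last.pt/best.pt are separate)."""
    return save_period > 0 and epoch % save_period == 0


def frozen_layer_prefixes(freeze: Optional[Union[int, Sequence[int]]]) -> List[str]:
    """An int N freezes layers 0..N-1; a list freezes exactly those layer indices."""
    if freeze is None:
        indices: Sequence[int] = []
    elif isinstance(freeze, int):
        indices = range(freeze)
    else:
        indices = freeze
    # Parameters whose names start with one of these prefixes would get requires_grad=False
    return [f"model.{i}." for i in indices]


# With epochs=100 and close_mosaic=10, which epochs train without mosaic?
no_mosaic = [e for e in range(100) if not mosaic_enabled(e, 100, 10)]
print(no_mosaic[0], no_mosaic[-1], len(no_mosaic))
```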
### Model Validation

Validate trained models on a dataset split to evaluate performance metrics.

```python { .api }
def val(self, data=None, split='val', imgsz=640, batch=16, **kwargs):
    """
    Validate the model on a dataset.

    Parameters:
    - data (str | Path): Path to dataset YAML file (default: the dataset stored with the model)
    - split (str): Dataset split to validate on ('val', 'test') (default: 'val')
    - imgsz (int): Image size for validation (default: 640)
    - batch (int): Batch size (default: 16)
    - conf (float): Confidence threshold (default: 0.001)
    - iou (float): IoU threshold for NMS (default: 0.6)
    - max_det (int): Maximum detections per image (default: 300)
    - half (bool): Use FP16 inference (default: True)
    - device (str): Device to run on ('cpu', '0', '0,1', etc.)
    - dnn (bool): Use OpenCV DNN for ONNX inference (default: False)
    - plots (bool): Save prediction plots (default: False)
    - save_txt (bool): Save results as txt files (default: False)
    - save_conf (bool): Include confidence in txt files (default: False)
    - save_json (bool): Save results as JSON (default: False)
    - project (str): Project directory (default: None, i.e. 'runs/<task>')
    - name (str): Experiment name (default: None, i.e. 'val')
    - exist_ok (bool): Reuse an existing project/name directory instead of incrementing (default: False)
    - verbose (bool): Verbose output (default: True)
    - workers (int): Number of dataloader worker threads (default: 8)

    Returns:
        Task-specific metrics object (e.g. DetMetrics for detection); box metrics such as mAP, precision and recall are under `.box`
    """
```
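The `conf`, `iou`, and `max_det` arguments control post-processing before any metric is computed. As an illustration of the technique only, the sketch below implements plain single-class greedy non-maximum suppression in pure Python; Ultralytics runs a batched, class-aware NMS on tensors, so `box_iou` and `nms` here are hypothetical helpers, not the library's code.

```python
from typing import List, Sequence, Tuple

# Illustrative sketch of greedy NMS; not the Ultralytics implementation.
Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels


def box_iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: Sequence[Box], scores: Sequence[float],
        conf: float = 0.001, iou: float = 0.6, max_det: int = 300) -> List[int]:
    """Return indices of kept boxes, highest score first."""
    # 1. Drop low-confidence boxes, then sort the rest by descending score
    order = sorted((i for i, s in enumerate(scores) if s > conf),
                   key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # 2. Keep a box only if it overlaps no already-kept box by more than `iou`
        if all(box_iou(boxes[i], boxes[j]) <= iou for j in keep):
            keep.append(i)
            # 3. Stop once max_det detections are kept
            if len(keep) == max_det:
                break
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, iou=0.6))
```

Raising `iou` lets more overlapping boxes survive, while raising `conf` removes low-scoring boxes before suppression even starts.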
**Usage Examples:**

```python
# Basic validation (uses the dataset and settings stored with the model)
metrics = model.val()

# Validate on specific dataset
metrics = model.val(data="custom_dataset.yaml")

# Validate with custom parameters
metrics = model.val(
    data="dataset.yaml",
    split='test',
    imgsz=1280,
    conf=0.25,
    iou=0.5,
    save_json=True
)

# Access validation metrics
print(f"mAP50: {metrics.box.map50}")
print(f"mAP50-95: {metrics.box.map}")
print(f"Precision: {metrics.box.mp}")
print(f"Recall: {metrics.box.mr}")
```
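`map50` and `map` summarize average precision (AP) across classes: mAP50 is the class-averaged AP at an IoU threshold of 0.5, and mAP50-95 additionally averages AP over the IoU thresholds 0.50, 0.55, ..., 0.95. The function below sketches COCO-style 101-point interpolated AP for one class at one threshold. It assumes the precision/recall points are already sorted by increasing recall; `average_precision` and `mean_ap` are hypothetical names, and the library's NumPy implementation may differ in numerical details.

```python
from typing import Sequence

# Illustrative sketch: hypothetical helpers, not the Ultralytics implementation.


def average_precision(recall: Sequence[float], precision: Sequence[float]) -> float:
    """101-point interpolated AP from a PR curve sorted by increasing recall."""
    # Pad so the curve starts at recall 0 and ends at recall 1
    mrec = [0.0, *recall, 1.0]
    mpre = [1.0, *precision, 0.0]
    # Precision envelope: each point takes the max precision at any higher recall
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    # Sample the envelope at recall = 0.00, 0.01, ..., 1.00 and average
    total = 0.0
    for k in range(101):
        r = k / 100
        total += next(p for rec, p in zip(mrec, mpre) if rec >= r)
    return total / 101


def mean_ap(ap_per_class: Sequence[float]) -> float:
    """Average AP over classes (one AP per class at IoU 0.5 gives mAP50)."""
    return sum(ap_per_class) / len(ap_per_class)


# A class whose detections stop at 50% recall scores roughly half the maximum AP
print(round(average_precision([0.5], [1.0]), 4))
```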
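### Hyperparameter Tuning

Search for better training hyperparameters. By default, `tune()` runs the built-in evolutionary tuner, which mutates hyperparameters between training runs and keeps the best-performing set; passing `use_ray=True` delegates the search to Ray Tune instead.

```python { .api }
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
    """
    Perform hyperparameter tuning.

    Parameters:
    - use_ray (bool): Use Ray Tune instead of the built-in evolutionary tuner (default: False)
    - iterations (int): Number of tuning iterations (default: 10); with Ray Tune, the number of trials
    - data (str | Path): Path to dataset YAML file
    - space (dict): Hyperparameter search space; the built-in tuner expects (min, max) tuples, Ray Tune expects ray.tune search objects
    - grace_period (int): Ray Tune only: minimum epochs before a trial can be stopped early (default: 10)
    - gpu_per_trial (float): Ray Tune only: GPU fraction allocated per trial
    - **kwargs: Additional training arguments

    Returns:
        Built-in tuner: results are written to the tuning run directory (e.g. best_hyperparameters.yaml, tune_results.csv)
        Ray Tune: the Ray Tune result grid
    """
```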
**Usage Examples:**

```python
# Basic hyperparameter tuning with the built-in evolutionary tuner
model.tune(data="dataset.yaml", iterations=30)

# Custom search space for the built-in tuner: (min, max) bounds per hyperparameter
search_space = {
    'lr0': (0.0001, 0.01),
    'momentum': (0.8, 0.95),
    'weight_decay': (0.0001, 0.001),
}

model.tune(
    data="dataset.yaml",
    space=search_space,
    iterations=50,
)

# Ray Tune backend (requires `pip install "ray[tune]"`); gpu_per_trial only applies here
result_grid = model.tune(data="dataset.yaml", use_ray=True, iterations=50, gpu_per_trial=0.5)
```
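Besides the Ray Tune integration, Ultralytics provides a built-in evolutionary tuner whose core idea is mutate-and-select: perturb the current best hyperparameters, train, and keep the candidate if fitness improves. The sketch below captures that loop under simplifying assumptions: `mutate` and `evolve` are hypothetical helpers, the Gaussian multiplicative perturbation (sigma 0.2, clipped to the `(min, max)` bounds) is an illustrative choice rather than the library's exact mutation rule, and `fitness_fn` stands in for a full training run. Starting values must be non-zero because the perturbation is multiplicative.

```python
import random
from typing import Callable, Dict, Optional, Tuple

# Illustrative sketch: hypothetical helpers, not the Ultralytics Tuner.
Space = Dict[str, Tuple[float, float]]


def mutate(parent: Dict[str, float], space: Space, sigma: float = 0.2,
           rng: Optional[random.Random] = None) -> Dict[str, float]:
    """Multiply each value by a random gain near 1.0, then clip to its (min, max) bounds."""
    rng = rng or random.Random()
    child = {}
    for key, (low, high) in space.items():
        gain = min(3.0, max(0.3, rng.gauss(1.0, sigma)))  # bound extreme gains
        child[key] = min(high, max(low, parent[key] * gain))
    return child


def evolve(fitness_fn: Callable[[Dict[str, float]], float], space: Space,
           start: Dict[str, float], iterations: int = 30, seed: int = 0):
    """Greedy (1+1) evolution: keep a mutated candidate only if it scores higher."""
    rng = random.Random(seed)
    best, best_fit = dict(start), fitness_fn(start)
    for _ in range(iterations):
        candidate = mutate(best, space, rng=rng)
        fit = fitness_fn(candidate)  # in practice: train, then return validation fitness
        if fit > best_fit:
            best, best_fit = candidate, fit
    return best, best_fit


# Toy objective standing in for a training run: best when lr0 is near 0.005
space = {"lr0": (0.0001, 0.01), "momentum": (0.8, 0.95)}
best, score = evolve(lambda hp: -abs(hp["lr0"] - 0.005), space, {"lr0": 0.01, "momentum": 0.937})
print(best, score)
```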
### Training Callbacks

Customize training behavior with callback functions.

```python { .api }
def add_callback(self, event: str, callback):
    """
    Add a callback function for a specific event.

    Parameters:
    - event (str): Event name ('on_train_epoch_end', 'on_train_batch_end', etc.)
    - callback: Callable invoked with the trainer (or the validator, for on_val_* events)
    """

def clear_callback(self, event: str):
    """Clear all callbacks registered for a specific event."""

def reset_callbacks(self):
    """Reset all callbacks to their default functions."""
```
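To make the `add_callback` / `clear_callback` / `reset_callbacks` semantics concrete, here is a toy, framework-free registry. `CallbackRegistry` and its `run_callbacks` method are hypothetical names for this sketch; it assumes, consistent with the docstrings above, that adding appends to an event's list, clearing empties that list (defaults included), and resetting restores the default callbacks.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List

# Illustrative sketch: a toy registry, not the Ultralytics implementation.
Callback = Callable[[Any], None]


class CallbackRegistry:
    """Per-event lists of callbacks, run in registration order."""

    def __init__(self, defaults: Dict[str, List[Callback]]):
        self._defaults = {event: list(fns) for event, fns in defaults.items()}
        self.reset_callbacks()

    def add_callback(self, event: str, callback: Callback) -> None:
        self.callbacks[event].append(callback)  # runs after existing callbacks

    def clear_callback(self, event: str) -> None:
        self.callbacks[event] = []  # removes the defaults too

    def reset_callbacks(self) -> None:
        self.callbacks = defaultdict(list, {e: list(f) for e, f in self._defaults.items()})

    def run_callbacks(self, event: str, trainer: Any) -> None:
        for callback in self.callbacks.get(event, []):
            callback(trainer)


# A default logger plus a custom callback on the same event
log: List[str] = []
registry = CallbackRegistry({"on_train_epoch_end": [lambda t: log.append(f"default:{t['epoch']}")]})
registry.add_callback("on_train_epoch_end", lambda t: log.append(f"custom:{t['epoch']}"))
registry.run_callbacks("on_train_epoch_end", {"epoch": 0})
print(log)
```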
**Available Events:**
- `on_pretrain_routine_start`
- `on_pretrain_routine_end`
- `on_train_start`
- `on_train_epoch_start`
- `on_train_batch_start`
- `optimizer_step`
- `on_before_zero_grad`
- `on_train_batch_end`
- `on_train_epoch_end`
- `on_val_start`
- `on_val_batch_start`
- `on_val_batch_end`
- `on_val_end`
- `on_fit_epoch_end`
- `on_model_save`
- `on_train_end`
- `teardown`

**Usage Examples:**

```python
import shutil

# Log the training loss at the end of each epoch
def log_loss(trainer):
    print(f"Epoch {trainer.epoch}: Loss = {trainer.loss}")

# Keep a per-epoch copy of best.pt whenever a new best fitness is reached
def save_best_model(trainer):
    if trainer.best_fitness == trainer.fitness:
        shutil.copy(trainer.best, trainer.wdir / f"best_model_epoch_{trainer.epoch}.pt")

model.add_callback('on_train_epoch_end', log_loss)
model.add_callback('on_model_save', save_best_model)  # runs after best.pt/last.pt are written

# Train with callbacks
results = model.train(data="dataset.yaml", epochs=100)

# Clear specific callback
model.clear_callback('on_train_epoch_end')

# Reset all callbacks
model.reset_callbacks()
```
## Training Data Format

### Dataset YAML Configuration

```yaml
# dataset.yaml
path: /path/to/dataset # dataset root dir
train: images/train # train images (relative to path)
val: images/val # val images (relative to path)
test: images/test # test images (optional)

# Classes
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  # ... more classes
```
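Relative split paths in the YAML are resolved against `path`, and `names` maps contiguous class indices to labels. The helper below is a hypothetical sketch of those checks applied to an already-parsed YAML mapping (parsing itself needs a YAML library such as PyYAML); the validation Ultralytics actually performs is more extensive and is not reproduced here.

```python
from pathlib import Path
from typing import Any, Dict

# Illustrative sketch: hypothetical helper, not the Ultralytics dataset checker.


def resolve_dataset(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve split paths against `path` and normalize `names` to {index: label}."""
    for key in ("train", "val", "names"):
        if key not in cfg:
            raise KeyError(f"dataset YAML is missing required key '{key}'")
    root = Path(cfg.get("path", "."))
    # `test` is optional; skip it when absent or empty
    splits = {k: root / cfg[k] for k in ("train", "val", "test") if cfg.get(k)}
    names = cfg["names"]
    if isinstance(names, list):  # also accept a plain list of class names
        names = dict(enumerate(names))
    if sorted(names) != list(range(len(names))):
        raise ValueError("class indices in `names` must be 0..nc-1 with no gaps")
    return {"splits": splits, "names": names, "nc": len(names)}


# The mapping a YAML loader would return for the example above (class list truncated)
cfg = {
    "path": "/path/to/dataset",
    "train": "images/train",
    "val": "images/val",
    "test": "images/test",
    "names": {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle"},
}
print(resolve_dataset(cfg)["splits"]["train"])
```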
### Directory Structure

```
dataset/
├── images/
│   ├── train/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   ├── val/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   └── test/
│       ├── image1.jpg
│       └── ...
└── labels/
    ├── train/
    │   ├── image1.txt
    │   ├── image2.txt
    │   └── ...
    ├── val/
    │   ├── image1.txt
    │   ├── image2.txt
    │   └── ...
    └── test/
        ├── image1.txt
        └── ...
```
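Each image is paired with the label file at the mirrored path under `labels/`. The sketch below assumes, as the tree shows, that the mapping only swaps the last `images` directory for `labels` and the file extension for `.txt`. The `to_yolo_line` helper is a hypothetical illustration of the standard YOLO detection label row, `class x_center y_center width height`, with coordinates normalized by the image width and height.

```python
import os
from typing import Tuple

# Illustrative sketch: hypothetical helpers, not part of the Ultralytics API.


def img2label_path(img_path: str) -> str:
    """Map .../images/<split>/name.jpg to .../labels/<split>/name.txt."""
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    # Replace only the last '/images/' component, then swap the file extension
    return sb.join(img_path.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt"


def to_yolo_line(cls: int, box_xyxy: Tuple[float, float, float, float], img_w: int, img_h: int) -> str:
    """Convert a pixel-space (x1, y1, x2, y2) box into a normalized YOLO label row."""
    x1, y1, x2, y2 = box_xyxy
    cx, cy = (x1 + x2) / 2 / img_w, (y1 + y2) / 2 / img_h
    w, h = (x2 - x1) / img_w, (y2 - y1) / img_h
    return f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


print(img2label_path(os.path.join("dataset", "images", "train", "image1.jpg")))
print(to_yolo_line(0, (10, 20, 110, 220), img_w=200, img_h=400))
```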
## Types

```python { .api }
from typing import Dict, Any, Optional, Union, List, Callable
from pathlib import Path

# Training configuration types
TrainingConfig = Dict[str, Any]
ValidationMetrics = Dict[str, float]
CallbackFunction = Callable[['BaseTrainer'], None]

# Common metric types
class MetricsClass:
    map: float  # mAP@0.5:0.95
    map50: float  # mAP@0.5
    map75: float  # mAP@0.75
    mp: float  # mean precision
    mr: float  # mean recall
    fitness: float  # weighted combination of metrics: 0.1 * map50 + 0.9 * map
```
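The `fitness` value is what training uses to decide which checkpoint counts as best. A minimal sketch, assuming the weighting of 0.0, 0.0, 0.1, 0.9 over (mp, mr, map50, map), so that fitness is dominated by mAP50-95; `BoxMetrics` is a hypothetical stand-in for `MetricsClass` above, with fitness computed from the other fields rather than stored.

```python
from dataclasses import dataclass

# Illustrative sketch: hypothetical stand-in for MetricsClass, not the library's class.


@dataclass
class BoxMetrics:
    map: float    # mAP@0.5:0.95
    map50: float  # mAP@0.5
    map75: float  # mAP@0.75
    mp: float     # mean precision
    mr: float     # mean recall

    def fitness(self) -> float:
        """Weighted sum of (mp, mr, map50, map) with weights (0.0, 0.0, 0.1, 0.9)."""
        weights = (0.0, 0.0, 0.1, 0.9)
        values = (self.mp, self.mr, self.map50, self.map)
        return sum(w * v for w, v in zip(weights, values))


m = BoxMetrics(map=0.40, map50=0.60, map75=0.45, mp=0.80, mr=0.70)
print(round(m.fitness(), 4))
```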