# Distributed Training

Ray Train provides distributed training capabilities for machine learning with support for PyTorch, TensorFlow, XGBoost, and other frameworks. It includes fault-tolerant training, automatic scaling, and seamless integration with Ray Data.

## Capabilities

### Core Training Framework

Base training functionality and configuration.

```python { .api }
class Trainer:
    """Base class for distributed training."""

    def __init__(self, *, run_config=None, scaling_config=None, **kwargs):
        """
        Initialize trainer.

        Args:
            run_config (RunConfig, optional): Run configuration
            scaling_config (ScalingConfig, optional): Scaling configuration
        """

    def fit(self, dataset=None):
        """
        Execute training.

        Args:
            dataset (Dataset, optional): Training dataset

        Returns:
            Result: Training results
        """

    def predict(self, dataset, *, checkpoint=None):
        """
        Make predictions using trained model.

        Args:
            dataset (Dataset): Dataset for prediction
            checkpoint (Checkpoint, optional): Model checkpoint

        Returns:
            Dataset: Predictions
        """

class RunConfig:
    """Configuration for training runs."""

    def __init__(self, *, name=None, local_dir=None, stop=None,
                 checkpoint_config=None, verbose=None, **kwargs):
        """
        Initialize run configuration.

        Args:
            name (str, optional): Run name
            local_dir (str, optional): Local directory for results
            stop (dict, optional): Stopping criteria
            checkpoint_config (CheckpointConfig, optional): Checkpoint config
            verbose (int, optional): Verbosity level
        """

class ScalingConfig:
    """Configuration for distributed scaling."""

    def __init__(self, *, num_workers=None, use_gpu=False,
                 resources_per_worker=None, placement_strategy="PACK"):
        """
        Initialize scaling configuration.

        Args:
            num_workers (int, optional): Number of workers
            use_gpu (bool): Whether to use GPU
            resources_per_worker (dict, optional): Resources per worker
            placement_strategy (str): Worker placement strategy
        """

class CheckpointConfig:
    """Configuration for model checkpointing."""

    def __init__(self, *, num_to_keep=None, checkpoint_score_attribute=None,
                 checkpoint_score_order="max"):
        """
        Initialize checkpoint configuration.

        Args:
            num_to_keep (int, optional): Number of checkpoints to keep
            checkpoint_score_attribute (str, optional): Metric to use for ranking
            checkpoint_score_order (str): "max" or "min" for ranking
        """
```

### PyTorch Training

Distributed PyTorch training with automatic data parallelism.

```python { .api }
class TorchTrainer(Trainer):
    """Distributed PyTorch trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 torch_config=None, **kwargs):
        """
        Initialize PyTorch trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training function
            torch_config (TorchConfig, optional): PyTorch-specific configuration
        """

class TorchConfig:
    """PyTorch-specific training configuration."""

    def __init__(self, *, backend="nccl", init_method="env://",
                 timeout_s=1800):
        """
        Initialize PyTorch configuration.

        Args:
            backend (str): Distributed backend ("nccl", "gloo")
            init_method (str): Process group initialization method
            timeout_s (int): Timeout for operations
        """

def get_device():
    """Get PyTorch device for current worker."""

def prepare_model(model, *, move_to_device=True, wrap_ddp=True):
    """
    Prepare model for distributed training.

    Args:
        model: PyTorch model
        move_to_device (bool): Move model to device
        wrap_ddp (bool): Wrap with DistributedDataParallel

    Returns:
        Prepared model
    """

def prepare_data_loader(data_loader, *, add_dist_sampler=True):
    """
    Prepare data loader for distributed training.

    Args:
        data_loader: PyTorch DataLoader
        add_dist_sampler (bool): Add distributed sampler

    Returns:
        Prepared data loader
    """

def prepare_optimizer(optimizer):
    """
    Prepare optimizer for distributed training.

    Args:
        optimizer: PyTorch optimizer

    Returns:
        Prepared optimizer
    """

class Checkpoint:
    """Training checkpoint."""

    def __init__(self, *, data_dict=None, path=None):
        """
        Initialize checkpoint.

        Args:
            data_dict (dict, optional): Checkpoint data
            path (str, optional): Path to checkpoint
        """

    @classmethod
    def from_dict(cls, data):
        """Create checkpoint from dictionary."""

    def to_dict(self):
        """Convert checkpoint to dictionary."""

def report(metrics, *, checkpoint=None):
    """
    Report training metrics and optionally save checkpoint.

    Args:
        metrics (dict): Training metrics
        checkpoint (Checkpoint, optional): Checkpoint to save
    """
```

### TensorFlow Training

Distributed TensorFlow training with MultiWorkerMirroredStrategy.

```python { .api }
class TensorflowTrainer(Trainer):
    """Distributed TensorFlow trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 tensorflow_config=None, **kwargs):
        """
        Initialize TensorFlow trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training function
            tensorflow_config (TensorflowConfig, optional): TF-specific configuration
        """

class TensorflowConfig:
    """TensorFlow-specific training configuration."""

    def __init__(self):
        """Initialize TensorFlow configuration."""

def setup_tensorflow_environment():
    """Setup TensorFlow distributed environment."""

def prepare_dataset_shard(tf_dataset):
    """
    Prepare TensorFlow dataset for distributed training.

    Args:
        tf_dataset: TensorFlow dataset

    Returns:
        Sharded dataset
    """
```

### XGBoost Training

Distributed XGBoost training.

```python { .api }
class XGBoostTrainer(Trainer):
    """Distributed XGBoost trainer."""

    def __init__(self, *, label_column, params=None, datasets=None,
                 **kwargs):
        """
        Initialize XGBoost trainer.

        Args:
            label_column (str): Label column name
            params (dict, optional): XGBoost parameters
            datasets (dict, optional): Additional datasets (validation, etc.)
        """

class GBDTTrainer(Trainer):
    """Base class for gradient boosting trainers."""

    def __init__(self, *, label_column, params=None, **kwargs):
        """
        Initialize GBDT trainer.

        Args:
            label_column (str): Label column name
            params (dict, optional): Training parameters
        """

class LightGBMTrainer(GBDTTrainer):
    """Distributed LightGBM trainer."""

class XGBoostConfig:
    """XGBoost-specific training configuration."""

    def __init__(self, *, xgb_params=None, train_params=None):
        """
        Initialize XGBoost configuration.

        Args:
            xgb_params (dict, optional): XGBoost model parameters
            train_params (dict, optional): Training parameters
        """
```

### Hugging Face Integration

Integration with Hugging Face Transformers.

```python { .api }
class HuggingFaceTrainer(Trainer):
    """Distributed Hugging Face trainer."""

    def __init__(self, *, trainer_init_per_worker, trainer_init_config=None,
                 **kwargs):
        """
        Initialize Hugging Face trainer.

        Args:
            trainer_init_per_worker: Function to initialize HF trainer
            trainer_init_config (dict, optional): Trainer initialization config
        """

class TransformersTrainer(HuggingFaceTrainer):
    """Transformers trainer (alias for HuggingFaceTrainer)."""
```

### Training Results and Checkpoints

Handle training results and model checkpoints.

```python { .api }
class Result:
    """Training result."""

    @property
    def metrics(self):
        """Training metrics."""

    @property
    def checkpoint(self):
        """Best checkpoint."""

    @property
    def path(self):
        """Result path."""

    @property
    def config(self):
        """Training configuration."""

class TorchCheckpoint:
    """PyTorch model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """Create checkpoint from PyTorch model."""

    def get_model(self, model_class=None):
        """Load PyTorch model from checkpoint."""

class TensorflowCheckpoint:
    """TensorFlow model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """Create checkpoint from TensorFlow model."""

    def get_model(self):
        """Load TensorFlow model from checkpoint."""

class XGBoostCheckpoint:
    """XGBoost model checkpoint."""

    @classmethod
    def from_model(cls, booster, *, preprocessor=None):
        """Create checkpoint from XGBoost booster."""

    def get_model(self):
        """Load XGBoost booster from checkpoint."""

class DataParallelTrainer(Trainer):
    """Base class for data parallel trainers."""

    def __init__(self, *, datasets=None, **kwargs):
        """
        Initialize data parallel trainer.

        Args:
            datasets (dict, optional): Training datasets
        """
```

## Usage Examples

### PyTorch Training Example

```python
import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

ray.init()

def train_loop_per_worker(config):
    # Define model and wrap it for distributed training
    model = nn.Linear(1, 1)
    model = train.torch.prepare_model(model)

    # Define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    optimizer = train.torch.prepare_optimizer(optimizer)

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here; requires_grad so backward() works on the placeholder
        loss = torch.tensor(0.1, requires_grad=True)  # Placeholder

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report metrics back to Ray Train
        train.report({"loss": loss.item(), "epoch": epoch})

# Configure trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 0.01, "num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(name="torch_training")
)

# Execute training
result = trainer.fit()
print(f"Final metrics: {result.metrics}")
```

### XGBoost Training Example

```python
import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

ray.init()

# Load data; test_dataset is the held-out split used for prediction below
train_dataset = ray.data.read_csv("train.csv")
test_dataset = ray.data.read_csv("test.csv")

# Configure trainer
trainer = XGBoostTrainer(
    label_column="target",
    params={
        "objective": "binary:logistic",
        "learning_rate": 0.1,
        "max_depth": 6
    },
    scaling_config=ScalingConfig(num_workers=4),
    run_config=RunConfig(name="xgboost_training")
)

# Execute training
result = trainer.fit(dataset=train_dataset)
print(result.metrics)

# Make predictions
predictions = trainer.predict(test_dataset, checkpoint=result.checkpoint)
```

### TensorFlow Training Example

```python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
import tensorflow as tf

ray.init()

def train_loop_per_worker(config):
    # Setup distributed training
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        # Define model
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1)
        ])

        model.compile(
            optimizer='adam',
            loss='mse',
            metrics=['mae']
        )

    # Placeholder training data — replace with a real dataset shard
    x_train = tf.random.uniform((64, 8))
    y_train = tf.random.uniform((64, 1))

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here
        history = model.fit(x_train, y_train, epochs=1, verbose=0)

        # Report metrics back to Ray Train
        train.report({
            "loss": history.history["loss"][0],
            "mae": history.history["mae"][0],
            "epoch": epoch
        })

# Configure trainer
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True)
)

result = trainer.fit()
```