or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-distributed.md, data-processing.md, distributed-training.md, hyperparameter-tuning.md, index.md, model-serving.md, reinforcement-learning.md, utilities-advanced.md

docs/distributed-training.md

# Distributed Training

Ray Train provides distributed training capabilities for machine learning with support for PyTorch, TensorFlow, XGBoost, and other frameworks. It includes fault-tolerant training, automatic scaling, and seamless integration with Ray Data.

## Capabilities

### Core Training Framework

Base training functionality and configuration.

```python { .api }
class Trainer:
    """Base class for distributed training."""

    def __init__(self, *, run_config=None, scaling_config=None, **kwargs):
        """
        Initialize trainer.

        Args:
            run_config (RunConfig, optional): Run configuration
            scaling_config (ScalingConfig, optional): Scaling configuration
        """

    def fit(self, dataset=None):
        """
        Execute training.

        Args:
            dataset (Dataset, optional): Training dataset

        Returns:
            Result: Training results
        """

    def predict(self, dataset, *, checkpoint=None):
        """
        Make predictions using trained model.

        Args:
            dataset (Dataset): Dataset for prediction
            checkpoint (Checkpoint, optional): Model checkpoint

        Returns:
            Dataset: Predictions
        """

class RunConfig:
    """Configuration for training runs."""

    def __init__(self, *, name=None, local_dir=None, stop=None,
                 checkpoint_config=None, verbose=None, **kwargs):
        """
        Initialize run configuration.

        Args:
            name (str, optional): Run name
            local_dir (str, optional): Local directory for results
            stop (dict, optional): Stopping criteria
            checkpoint_config (CheckpointConfig, optional): Checkpoint config
            verbose (int, optional): Verbosity level
        """

class ScalingConfig:
    """Configuration for distributed scaling."""

    def __init__(self, *, num_workers=None, use_gpu=False,
                 resources_per_worker=None, placement_strategy="PACK"):
        """
        Initialize scaling configuration.

        Args:
            num_workers (int, optional): Number of workers
            use_gpu (bool): Whether to use GPU
            resources_per_worker (dict, optional): Resources per worker
            placement_strategy (str): Worker placement strategy
        """

class CheckpointConfig:
    """Configuration for model checkpointing."""

    def __init__(self, *, num_to_keep=None, checkpoint_score_attribute=None,
                 checkpoint_score_order="max"):
        """
        Initialize checkpoint configuration.

        Args:
            num_to_keep (int, optional): Number of checkpoints to keep
            checkpoint_score_attribute (str, optional): Metric to use for ranking
            checkpoint_score_order (str): "max" or "min" for ranking
        """
```

### PyTorch Training

Distributed PyTorch training with automatic data parallelism.

```python { .api }
class TorchTrainer(Trainer):
    """Distributed PyTorch trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 torch_config=None, **kwargs):
        """
        Initialize PyTorch trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training function
            torch_config (TorchConfig, optional): PyTorch-specific configuration
        """

class TorchConfig:
    """PyTorch-specific training configuration."""

    def __init__(self, *, backend="nccl", init_method="env://",
                 timeout_s=1800):
        """
        Initialize PyTorch configuration.

        Args:
            backend (str): Distributed backend ("nccl", "gloo")
            init_method (str): Process group initialization method
            timeout_s (int): Timeout for operations
        """

def get_device():
    """Get PyTorch device for current worker."""

def prepare_model(model, *, move_to_device=True, wrap_ddp=True):
    """
    Prepare model for distributed training.

    Args:
        model: PyTorch model
        move_to_device (bool): Move model to device
        wrap_ddp (bool): Wrap with DistributedDataParallel

    Returns:
        Prepared model
    """

def prepare_data_loader(data_loader, *, add_dist_sampler=True):
    """
    Prepare data loader for distributed training.

    Args:
        data_loader: PyTorch DataLoader
        add_dist_sampler (bool): Add distributed sampler

    Returns:
        Prepared data loader
    """

def prepare_optimizer(optimizer):
    """
    Prepare optimizer for distributed training.

    Args:
        optimizer: PyTorch optimizer

    Returns:
        Prepared optimizer
    """

class Checkpoint:
    """Training checkpoint."""

    def __init__(self, *, data_dict=None, path=None):
        """
        Initialize checkpoint.

        Args:
            data_dict (dict, optional): Checkpoint data
            path (str, optional): Path to checkpoint
        """

    @classmethod
    def from_dict(cls, data):
        """Create checkpoint from dictionary."""

    def to_dict(self):
        """Convert checkpoint to dictionary."""

def report(metrics, *, checkpoint=None):
    """
    Report training metrics and optionally save checkpoint.

    Args:
        metrics (dict): Training metrics
        checkpoint (Checkpoint, optional): Checkpoint to save
    """
```

### TensorFlow Training

Distributed TensorFlow training with MultiWorkerMirroredStrategy.

```python { .api }
class TensorflowTrainer(Trainer):
    """Distributed TensorFlow trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 tensorflow_config=None, **kwargs):
        """
        Initialize TensorFlow trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training function
            tensorflow_config (TensorflowConfig, optional): TF-specific configuration
        """

class TensorflowConfig:
    """TensorFlow-specific training configuration."""

    def __init__(self):
        """Initialize TensorFlow configuration."""

def setup_tensorflow_environment():
    """Setup TensorFlow distributed environment."""

def prepare_dataset_shard(tf_dataset):
    """
    Prepare TensorFlow dataset for distributed training.

    Args:
        tf_dataset: TensorFlow dataset

    Returns:
        Sharded dataset
    """
```

### XGBoost Training

Distributed XGBoost training.

```python { .api }
class XGBoostTrainer(Trainer):
    """Distributed XGBoost trainer."""

    def __init__(self, *, label_column, params=None, datasets=None,
                 **kwargs):
        """
        Initialize XGBoost trainer.

        Args:
            label_column (str): Label column name
            params (dict, optional): XGBoost parameters
            datasets (dict, optional): Additional datasets (validation, etc.)
        """

class GBDTTrainer(Trainer):
    """Base class for gradient boosting trainers."""

    def __init__(self, *, label_column, params=None, **kwargs):
        """
        Initialize GBDT trainer.

        Args:
            label_column (str): Label column name
            params (dict, optional): Training parameters
        """

class LightGBMTrainer(GBDTTrainer):
    """Distributed LightGBM trainer."""

class XGBoostConfig:
    """XGBoost-specific training configuration."""

    def __init__(self, *, xgb_params=None, train_params=None):
        """
        Initialize XGBoost configuration.

        Args:
            xgb_params (dict, optional): XGBoost model parameters
            train_params (dict, optional): Training parameters
        """
```

### Hugging Face Integration

Integration with Hugging Face Transformers.

```python { .api }
class HuggingFaceTrainer(Trainer):
    """Distributed Hugging Face trainer."""

    def __init__(self, *, trainer_init_per_worker, trainer_init_config=None,
                 **kwargs):
        """
        Initialize Hugging Face trainer.

        Args:
            trainer_init_per_worker: Function to initialize HF trainer
            trainer_init_config (dict, optional): Trainer initialization config
        """

class TransformersTrainer(HuggingFaceTrainer):
    """Transformers trainer (alias for HuggingFaceTrainer)."""
```

### Training Results and Checkpoints

Handle training results and model checkpoints.

```python { .api }
class Result:
    """Training result."""

    @property
    def metrics(self):
        """Training metrics."""

    @property
    def checkpoint(self):
        """Best checkpoint."""

    @property
    def path(self):
        """Result path."""

    @property
    def config(self):
        """Training configuration."""

class TorchCheckpoint:
    """PyTorch model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """Create checkpoint from PyTorch model."""

    def get_model(self, model_class=None):
        """Load PyTorch model from checkpoint."""

class TensorflowCheckpoint:
    """TensorFlow model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """Create checkpoint from TensorFlow model."""

    def get_model(self):
        """Load TensorFlow model from checkpoint."""

class XGBoostCheckpoint:
    """XGBoost model checkpoint."""

    @classmethod
    def from_model(cls, booster, *, preprocessor=None):
        """Create checkpoint from XGBoost booster."""

    def get_model(self):
        """Load XGBoost booster from checkpoint."""

class DataParallelTrainer(Trainer):
    """Base class for data parallel trainers."""

    def __init__(self, *, datasets=None, **kwargs):
        """
        Initialize data parallel trainer.

        Args:
            datasets (dict, optional): Training datasets
        """
```

## Usage Examples

### PyTorch Training Example

```python
import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

ray.init()

def train_loop_per_worker(config):
    # Define model
    model = nn.Linear(1, 1)
    model = train.torch.prepare_model(model)

    # Define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    optimizer = train.torch.prepare_optimizer(optimizer)

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here
        loss = torch.tensor(0.1)  # Placeholder

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report metrics
        train.report({"loss": loss.item(), "epoch": epoch})

# Configure trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 0.01, "num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(name="torch_training")
)

# Execute training
result = trainer.fit()
print(f"Final metrics: {result.metrics}")
```

### XGBoost Training Example

```python
import ray
from ray import train
from ray.train.xgboost import XGBoostTrainer

ray.init()

# Load data
train_dataset = ray.data.read_csv("train.csv")

# Configure trainer
trainer = XGBoostTrainer(
    label_column="target",
    params={
        "objective": "binary:logistic",
        "learning_rate": 0.1,
        "max_depth": 6
    },
    scaling_config=ScalingConfig(num_workers=4),
    run_config=RunConfig(name="xgboost_training")
)

# Execute training
result = trainer.fit(dataset=train_dataset)
print(result.metrics)

# Make predictions
predictions = trainer.predict(test_dataset, checkpoint=result.checkpoint)
```

### TensorFlow Training Example

```python
import ray
from ray import train
from ray.train.tensorflow import TensorflowTrainer
import tensorflow as tf

ray.init()

def train_loop_per_worker(config):
    # Setup distributed training
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        # Define model
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1)
        ])

        model.compile(
            optimizer='adam',
            loss='mse',
            metrics=['mae']
        )

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here
        history = model.fit(x_train, y_train, epochs=1, verbose=0)

        # Report metrics
        train.report({
            "loss": history.history["loss"][0],
            "mae": history.history["mae"][0],
            "epoch": epoch
        })

# Configure trainer
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True)
)

result = trainer.fit()
```