Ray is a unified framework for scaling AI and Python applications.
—
Ray Train provides distributed training capabilities for machine learning with support for PyTorch, TensorFlow, XGBoost, and other frameworks. It includes fault-tolerant training, automatic scaling, and seamless integration with Ray Data.
Base training functionality and configuration.
class Trainer:
    """Base class for distributed training."""

    def __init__(self, *, run_config=None, scaling_config=None, **kwargs):
        """
        Initialize trainer.

        All arguments are keyword-only.

        Args:
            run_config (RunConfig, optional): Run configuration
            scaling_config (ScalingConfig, optional): Scaling configuration
            **kwargs: Additional keyword arguments accepted by the
                signature; their meaning is not described in this reference.
        """

    def fit(self, dataset=None):
        """
        Execute training.

        Args:
            dataset (Dataset, optional): Training dataset

        Returns:
            Result: Training results
        """

    def predict(self, dataset, *, checkpoint=None):
        """
        Make predictions using trained model.

        Args:
            dataset (Dataset): Dataset for prediction
            checkpoint (Checkpoint, optional): Model checkpoint
                (keyword-only)

        Returns:
            Dataset: Predictions
        """
class RunConfig:
    """Configuration for training runs."""

    def __init__(self, *, name=None, local_dir=None, stop=None,
                 checkpoint_config=None, verbose=None, **kwargs):
        """
        Initialize run configuration.

        All arguments are keyword-only.

        Args:
            name (str, optional): Run name
            local_dir (str, optional): Local directory for results
            stop (dict, optional): Stopping criteria
            checkpoint_config (CheckpointConfig, optional): Checkpoint config
            verbose (int, optional): Verbosity level
            **kwargs: Additional keyword arguments accepted by the
                signature; their meaning is not described in this reference.
        """
class ScalingConfig:
    """Configuration for distributed scaling."""

    def __init__(self, *, num_workers=None, use_gpu=False,
                 resources_per_worker=None, placement_strategy="PACK"):
        """
        Initialize scaling configuration.

        All arguments are keyword-only.

        Args:
            num_workers (int, optional): Number of workers
            use_gpu (bool): Whether to use GPU. Defaults to False.
            resources_per_worker (dict, optional): Resources per worker
            placement_strategy (str): Worker placement strategy.
                Defaults to "PACK".
        """
class CheckpointConfig:
    """Configuration for model checkpointing."""

    def __init__(self, *, num_to_keep=None, checkpoint_score_attribute=None,
                 checkpoint_score_order="max"):
        """
        Initialize checkpoint configuration.

        All arguments are keyword-only.

        Args:
            num_to_keep (int, optional): Number of checkpoints to keep
            checkpoint_score_attribute (str, optional): Metric to use for
                ranking
            checkpoint_score_order (str): "max" or "min" for ranking.
                Defaults to "max".
        """


# Distributed PyTorch training with automatic data parallelism.
class TorchTrainer(Trainer):
    """Distributed PyTorch trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 torch_config=None, **kwargs):
        """
        Initialize PyTorch trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training
                function
            torch_config (TorchConfig, optional): PyTorch-specific
                configuration
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """
class TorchConfig:
    """PyTorch-specific training configuration."""

    def __init__(self, *, backend="nccl", init_method="env://",
                 timeout_s=1800):
        """
        Initialize PyTorch configuration.

        All arguments are keyword-only.

        Args:
            backend (str): Distributed backend ("nccl", "gloo").
                Defaults to "nccl".
            init_method (str): Process group initialization method.
                Defaults to "env://".
            timeout_s (int): Timeout for operations. Defaults to 1800.
        """
def get_device():
    """Get PyTorch device for current worker."""
def prepare_model(model, *, move_to_device=True, wrap_ddp=True):
    """
    Prepare model for distributed training.

    Args:
        model: PyTorch model
        move_to_device (bool): Move model to device. Defaults to True.
        wrap_ddp (bool): Wrap with DistributedDataParallel.
            Defaults to True.

    Returns:
        Prepared model
    """
def prepare_data_loader(data_loader, *, add_dist_sampler=True):
    """
    Prepare data loader for distributed training.

    Args:
        data_loader: PyTorch DataLoader
        add_dist_sampler (bool): Add distributed sampler.
            Defaults to True.

    Returns:
        Prepared data loader
    """
def prepare_optimizer(optimizer):
    """
    Prepare optimizer for distributed training.

    Args:
        optimizer: PyTorch optimizer

    Returns:
        Prepared optimizer
    """
class Checkpoint:
    """Training checkpoint."""

    def __init__(self, *, data_dict=None, path=None):
        """
        Initialize checkpoint.

        All arguments are keyword-only.

        Args:
            data_dict (dict, optional): Checkpoint data
            path (str, optional): Path to checkpoint
        """

    @classmethod
    def from_dict(cls, data):
        """Create checkpoint from dictionary."""

    def to_dict(self):
        """Convert checkpoint to dictionary."""
def report(metrics, *, checkpoint=None):
    """
    Report training metrics and optionally save checkpoint.

    Args:
        metrics (dict): Training metrics
        checkpoint (Checkpoint, optional): Checkpoint to save
            (keyword-only)
    """


# Distributed TensorFlow training with MultiWorkerMirroredStrategy.
class TensorflowTrainer(Trainer):
    """Distributed TensorFlow trainer."""

    def __init__(self, train_loop_per_worker, *, train_loop_config=None,
                 tensorflow_config=None, **kwargs):
        """
        Initialize TensorFlow trainer.

        Args:
            train_loop_per_worker: Training function to run on each worker
            train_loop_config (dict, optional): Config passed to training
                function
            tensorflow_config (TensorflowConfig, optional): TF-specific
                configuration
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """
class TensorflowConfig:
    """TensorFlow-specific training configuration."""

    def __init__(self):
        """Initialize TensorFlow configuration."""
def setup_tensorflow_environment():
    """Setup TensorFlow distributed environment."""
def prepare_dataset_shard(tf_dataset):
    """
    Prepare TensorFlow dataset for distributed training.

    Args:
        tf_dataset: TensorFlow dataset

    Returns:
        Sharded dataset
    """


# Distributed XGBoost training.
class XGBoostTrainer(Trainer):
    """Distributed XGBoost trainer."""

    def __init__(self, *, label_column, params=None, datasets=None,
                 **kwargs):
        """
        Initialize XGBoost trainer.

        All arguments are keyword-only; label_column has no default and
        must be supplied.

        Args:
            label_column (str): Label column name
            params (dict, optional): XGBoost parameters
            datasets (dict, optional): Additional datasets
                (validation, etc.)
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """
class GBDTTrainer(Trainer):
    """Base class for gradient boosting trainers."""

    def __init__(self, *, label_column, params=None, **kwargs):
        """
        Initialize GBDT trainer.

        All arguments are keyword-only; label_column has no default and
        must be supplied.

        Args:
            label_column (str): Label column name
            params (dict, optional): Training parameters
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """
class LightGBMTrainer(GBDTTrainer):
    """Distributed LightGBM trainer."""
class XGBoostConfig:
    """XGBoost-specific training configuration."""

    def __init__(self, *, xgb_params=None, train_params=None):
        """
        Initialize XGBoost configuration.

        All arguments are keyword-only.

        Args:
            xgb_params (dict, optional): XGBoost model parameters
            train_params (dict, optional): Training parameters
        """


# Integration with Hugging Face Transformers.
class HuggingFaceTrainer(Trainer):
    """Distributed Hugging Face trainer."""

    def __init__(self, *, trainer_init_per_worker, trainer_init_config=None,
                 **kwargs):
        """
        Initialize Hugging Face trainer.

        All arguments are keyword-only; trainer_init_per_worker has no
        default and must be supplied.

        Args:
            trainer_init_per_worker: Function to initialize HF trainer
            trainer_init_config (dict, optional): Trainer initialization
                config
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """
class TransformersTrainer(HuggingFaceTrainer):
    """Transformers trainer (alias for HuggingFaceTrainer)."""


# Handle training results and model checkpoints.
class Result:
    """Training result."""

    @property
    def metrics(self):
        """Training metrics."""

    @property
    def checkpoint(self):
        """Best checkpoint."""

    @property
    def path(self):
        """Result path."""

    @property
    def config(self):
        """Training configuration."""
class TorchCheckpoint:
    """PyTorch model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """
        Create checkpoint from PyTorch model.

        Args:
            model: PyTorch model
            preprocessor (optional): Keyword-only; not described further
                in this reference.
        """

    def get_model(self, model_class=None):
        """
        Load PyTorch model from checkpoint.

        Args:
            model_class (optional): Not described further in this
                reference.
        """
class TensorflowCheckpoint:
    """TensorFlow model checkpoint."""

    @classmethod
    def from_model(cls, model, *, preprocessor=None):
        """
        Create checkpoint from TensorFlow model.

        Args:
            model: TensorFlow model
            preprocessor (optional): Keyword-only; not described further
                in this reference.
        """

    def get_model(self):
        """Load TensorFlow model from checkpoint."""
class XGBoostCheckpoint:
    """XGBoost model checkpoint."""

    @classmethod
    def from_model(cls, booster, *, preprocessor=None):
        """
        Create checkpoint from XGBoost booster.

        Args:
            booster: XGBoost booster
            preprocessor (optional): Keyword-only; not described further
                in this reference.
        """

    def get_model(self):
        """Load XGBoost booster from checkpoint."""
class DataParallelTrainer(Trainer):
    """Base class for data parallel trainers."""

    def __init__(self, *, datasets=None, **kwargs):
        """
        Initialize data parallel trainer.

        All arguments are keyword-only.

        Args:
            datasets (dict, optional): Training datasets
            **kwargs: Additional keyword arguments; presumably the base
                Trainer options (run_config, scaling_config) -- confirm.
        """


import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

ray.init()


def train_loop_per_worker(config):
    """Per-worker training function for the PyTorch usage example.

    Args:
        config (dict): Must contain "lr" and "num_epochs"; both keys are
            read below.
    """
    # Define model
    model = nn.Linear(1, 1)
    model = train.torch.prepare_model(model)

    # Define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    optimizer = train.torch.prepare_optimizer(optimizer)

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here.
        # Placeholder loss. It must require grad: calling backward() on a
        # plain constant tensor raises a RuntimeError because it has no
        # grad_fn. Replace with a loss computed from the model's output.
        loss = torch.tensor(0.1, requires_grad=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report metrics
        train.report({"loss": loss.item(), "epoch": epoch})


# Configure trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 0.01, "num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(name="torch_training"),
)

# Execute training
result = trainer.fit()
print(f"Final metrics: {result.metrics}")

import ray
from ray import train
# RunConfig and ScalingConfig are used below, so this snippet imports them
# itself rather than relying on an earlier example.
from ray.train import RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

ray.init()

# Load data
train_dataset = ray.data.read_csv("train.csv")
# Held-out data for the prediction step below; "test.csv" is a placeholder
# path, substitute your own.
test_dataset = ray.data.read_csv("test.csv")

# Configure trainer
trainer = XGBoostTrainer(
    label_column="target",
    params={
        "objective": "binary:logistic",
        "learning_rate": 0.1,
        "max_depth": 6,
    },
    scaling_config=ScalingConfig(num_workers=4),
    run_config=RunConfig(name="xgboost_training"),
)

# Execute training
result = trainer.fit(dataset=train_dataset)
print(result.metrics)

# Make predictions
predictions = trainer.predict(test_dataset, checkpoint=result.checkpoint)

import ray
from ray import train
# ScalingConfig is used below, so this snippet imports it itself rather
# than relying on an earlier example.
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
import tensorflow as tf

ray.init()


def train_loop_per_worker(config):
    """Per-worker training function for the TensorFlow usage example.

    Args:
        config (dict): Must contain "num_epochs", which is read below.
    """
    # Setup distributed training
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        # Define model
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1),
        ])
        model.compile(
            optimizer='adam',
            loss='mse',
            metrics=['mae'],
        )

    # Placeholder training data so the example is self-contained
    # (x_train / y_train were otherwise undefined); replace with real data.
    x_train = tf.random.normal((32, 1))
    y_train = 2.0 * x_train

    # Training loop
    for epoch in range(config["num_epochs"]):
        # Training logic here
        history = model.fit(x_train, y_train, epochs=1, verbose=0)

        # Report metrics
        train.report({
            "loss": history.history["loss"][0],
            "mae": history.history["mae"][0],
            "epoch": epoch,
        })


# Configure trainer
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)

result = trainer.fit()

# Install with Tessl CLI
#     npx tessl i tessl/pypi-ray