CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-optax

A gradient processing and optimization library in JAX

Pending
Overview
Eval results
Files

schedules.mddocs/

Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters including warmup, decay, and cyclic schedules. These schedules help optimize training dynamics and achieve better convergence.

Capabilities

Basic Schedules

Constant Schedule

def constant_schedule(value):
    """
    Constant value schedule.
    
    Args:
        value: Constant value to return
    
    Returns:
        Schedule function
    """

Linear Schedule

def linear_schedule(init_value, end_value, transition_steps):
    """
    Linear interpolation between two values.
    
    Args:
        init_value: Initial value
        end_value: Final value
        transition_steps: Number of steps for transition
    
    Returns:
        Schedule function
    """

Polynomial Schedule

def polynomial_schedule(init_value, end_value, power, transition_steps):
    """
    Polynomial decay schedule.
    
    Args:
        init_value: Initial value
        end_value: Final value
        power: Polynomial power (1.0 = linear, 2.0 = quadratic, etc.)
        transition_steps: Number of steps for transition
    
    Returns:
        Schedule function
    """

Exponential Decay

def exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
    """
    Exponential decay schedule.
    
    Args:
        init_value: Initial value
        transition_steps: Steps between decay applications
        decay_rate: Decay rate (e.g., 0.96 for 4% decay)
        transition_begin: Step to begin decay (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum value to decay to (default: None)
    
    Returns:
        Schedule function
    """

Cosine Schedules

Cosine Decay

def cosine_decay_schedule(init_value, decay_steps, alpha=0.0):
    """
    Cosine decay schedule.
    
    Args:
        init_value: Initial value
        decay_steps: Number of steps for full cosine cycle
        alpha: Minimum value as fraction of init_value (default: 0.0)
    
    Returns:
        Schedule function
    """

Cosine One-Cycle

def cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """
    One-cycle cosine schedule (cosine warmup to peak, then cosine decay).
    
    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at peak
        pct_start: Percentage of steps for warmup phase (default: 0.3)
        div_factor: Initial value divisor (initial value = peak_value / div_factor) (default: 25.0)
        final_div_factor: Final value divisor (default: 1e4)
    
    Returns:
        Schedule function
    """

Piecewise Schedules

Piecewise Constant

def piecewise_constant_schedule(init_value, boundaries_and_scales=None):
    """
    Piecewise constant schedule with different values in different intervals.
    
    Args:
        init_value: Initial value
        boundaries_and_scales: Dict mapping step boundaries to multiplicative
            scale factors applied to the current value at each boundary
    
    Returns:
        Schedule function
    """

Piecewise Interpolate

def piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales):
    """
    Piecewise schedule with interpolation between boundaries.
    
    Args:
        interpolate_type: Type of interpolation ('linear', 'cosine')
        init_value: Initial value
        boundaries_and_scales: Dict mapping boundaries to scale factors
    
    Returns:
        Schedule function
    """

Warmup Schedules

Warmup + Constant

def warmup_constant_schedule(init_value, peak_value, warmup_steps):
    """
    Linear warmup followed by constant value.
    
    Args:
        init_value: Initial value during warmup
        peak_value: Constant value after warmup
        warmup_steps: Number of warmup steps
    
    Returns:
        Schedule function
    """

Warmup + Cosine Decay

def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0):
    """
    Linear warmup followed by cosine decay.
    
    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        decay_steps: Number of decay steps after warmup
        end_value: Final value after decay (default: 0.0)
    
    Returns:
        Schedule function
    """

Warmup + Exponential Decay

def warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
    """
    Linear warmup followed by exponential decay.
    
    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        transition_steps: Steps between decay applications
        decay_rate: Exponential decay rate
        transition_begin: Step to begin decay (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum decay value (default: None)
    
    Returns:
        Schedule function
    """

Advanced Schedules

Linear One-Cycle

def linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=1e4):
    """
    One-cycle linear schedule (warmup, decay, final decay).
    
    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at peak
        pct_start: Percentage of steps for warmup phase (default: 0.3)
        pct_final: Percentage of steps before final decay (default: 0.85)
        div_factor: Initial value divisor (initial value = peak_value / div_factor) (default: 25.0)
        final_div_factor: Final value divisor (default: 1e4)
    
    Returns:
        Schedule function
    """

SGDR Schedule

def sgdr_schedule(cosine_kwargs):
    """
    Stochastic Gradient Descent with Warm Restarts (SGDR) schedule.
    
    Args:
        cosine_kwargs: Iterable of dicts, one per restart cycle, each providing
            keyword arguments for warmup_cosine_decay_schedule (e.g. init_value,
            peak_value, warmup_steps, decay_steps)
    
    Returns:
        Schedule function
    """

Schedule Composition

Join Schedules

def join_schedules(schedules, boundaries):
    """
    Join multiple schedules at specified boundaries.
    
    Args:
        schedules: List of schedule functions
        boundaries: List of step boundaries for schedule transitions
    
    Returns:
        Combined schedule function
    """

Hyperparameter Injection

Static Hyperparameters

def inject_hyperparams(transformation, **scheduled_hyperparams):
    """
    Inject scheduled hyperparameters into transformation.
    
    Args:
        transformation: Base gradient transformation
        **scheduled_hyperparams: Named schedule functions for hyperparameters
    
    Returns:
        GradientTransformation with scheduled hyperparameters
    """

Stateful Hyperparameters

def inject_stateful_hyperparams(transformation, **scheduled_hyperparams):
    """
    Inject stateful scheduled hyperparameters into transformation.
    
    Args:
        transformation: Base gradient transformation
        **scheduled_hyperparams: Named stateful schedule functions
    
    Returns:
        GradientTransformation with stateful scheduled hyperparameters
    """

Schedule State Classes

class InjectHyperparamsState:
    """State for hyperparameter injection."""
    count: int
    inner_state: OptState

class InjectStatefulHyperparamsState:
    """State for stateful hyperparameter injection."""
    count: int
    inner_state: OptState
    hyperparams_states: dict

class WrappedSchedule:
    """Wrapper for schedule functions with state."""
    schedule_fn: Schedule

Usage Examples

Basic Schedule Usage

import optax

# Create different schedules
constant_lr = optax.constant_schedule(0.001)
linear_decay = optax.linear_schedule(0.001, 0.0001, 1000)
cosine_decay = optax.cosine_decay_schedule(0.001, 1000)
exponential_decay = optax.exponential_decay(0.001, 100, 0.96)  # decay by 0.96 every 100 steps

# Use schedule with optimizer
optimizer = optax.adam(learning_rate=cosine_decay)

# Evaluate schedule at different steps
step_0_lr = constant_lr(0)      # 0.001
step_500_lr = linear_decay(500)  # 0.00055 (halfway between 0.001 and 0.0001)
step_1000_lr = cosine_decay(1000)  # 0.0 (alpha=0.0, so decays all the way to zero)

Warmup Schedules

# Warmup followed by cosine decay
warmup_cosine = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000,
    end_value=0.00001
)

# Warmup followed by constant
warmup_constant = optax.warmup_constant_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=500
)

# Use with optimizer
optimizer = optax.adamw(learning_rate=warmup_cosine, weight_decay=0.01)

Piecewise Schedules

# Multiplicative scale factors applied to the current value at each boundary step
boundaries_and_scales = {
    500: 1.0,    # value unchanged at step 500
    1000: 0.5,   # value multiplied by 0.5 at step 1000
    1500: 0.1    # value multiplied by 0.1 at step 1500 (cumulative scale: 0.05)
}

piecewise_sched = optax.piecewise_constant_schedule(0.001, boundaries_and_scales)

# With interpolation
piecewise_interp = optax.piecewise_interpolate_schedule(
    'linear', 0.001, boundaries_and_scales
)

Advanced Scheduling

# One-cycle schedule
onecycle = optax.cosine_onecycle_schedule(
    transition_steps=5000,
    peak_value=0.01,
    pct_start=0.3      # 30% of steps spent warming up to the peak
)

# SGDR with restarts: one kwargs dict per cosine cycle
sgdr = optax.sgdr_schedule([
    dict(init_value=0.0, peak_value=0.001, warmup_steps=0, decay_steps=1000),
    dict(init_value=0.0, peak_value=0.001, warmup_steps=0, decay_steps=2000),
])

# Join multiple schedules
schedules = [
    optax.constant_schedule(0.001),     # First 1000 steps
    optax.linear_schedule(0.001, 0.0001, 1000)  # Next 1000 steps
]
joined = optax.join_schedules(schedules, [1000])

Hyperparameter Scheduling

# Schedule multiple hyperparameters
base_transform = optax.scale_by_adam()

scheduled_transform = optax.inject_hyperparams(
    base_transform,
    learning_rate=optax.cosine_decay_schedule(0.001, 1000),
    b1=optax.linear_schedule(0.9, 0.95, 500),
    b2=optax.constant_schedule(0.999)
)

# Create complete optimizer
optimizer = optax.chain(
    scheduled_transform,
    optax.scale(-1.0)  # Apply negative learning rate
)

Training Loop Integration

import jax

# Create schedule
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000
)

optimizer = optax.adam(learning_rate=schedule)

def train_step(params, opt_state, batch, step):
    """Training step with scheduled learning rate."""
    
    def loss_fn(p):
        return compute_loss(p, batch)
    
    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    # Current learning rate for logging
    current_lr = schedule(step)
    
    return params, opt_state, loss_val, current_lr

Install with Tessl CLI

npx tessl i tessl/pypi-optax

docs

advanced-optimizers.md

assignment.md

contrib.md

index.md

losses.md

monte-carlo.md

optimizers.md

perturbations.md

projections.md

schedules.md

second-order.md

transformations.md

tree-utilities.md

utilities.md

tile.json