A gradient processing and optimization library in JAX.

Flexible scheduling functions for learning rates and other hyperparameters, including warmup, decay, and cyclic schedules. These schedules help optimize training dynamics and achieve better convergence.
def constant_schedule(value):
"""
Constant value schedule.
Args:
value: Constant value to return
Returns:
Schedule function
"""def linear_schedule(init_value, end_value, transition_steps):
"""
Linear interpolation between two values.
Args:
init_value: Initial value
end_value: Final value
transition_steps: Number of steps for transition
Returns:
Schedule function
"""def polynomial_schedule(init_value, end_value, power, transition_steps):
"""
Polynomial decay schedule.
Args:
init_value: Initial value
end_value: Final value
power: Polynomial power (1.0 = linear, 2.0 = quadratic, etc.)
transition_steps: Number of steps for transition
Returns:
Schedule function
"""def exponential_decay(init_value, decay_rate, transition_steps, transition_begin=0, staircase=False, end_value=None):
"""
Exponential decay schedule.
Args:
init_value: Initial value
decay_rate: Decay rate (e.g., 0.96 for 4% decay)
transition_steps: Steps between decay applications
transition_begin: Step to begin decay (default: 0)
staircase: Whether to apply decay in discrete steps (default: False)
end_value: Minimum value to decay to (default: None)
Returns:
Schedule function
"""def cosine_decay_schedule(init_value, decay_steps, alpha=0.0):
"""
Cosine decay schedule.
Args:
init_value: Initial value
decay_steps: Number of steps for full cosine cycle
alpha: Minimum value as fraction of init_value (default: 0.0)
Returns:
Schedule function
"""def cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, final_div_factor=1e4):
"""
One-cycle cosine schedule (warmup, decay, final decay).
Args:
transition_steps: Total number of steps
peak_value: Maximum value at peak
pct_start: Percentage of steps for warmup phase (default: 0.3)
pct_final: Percentage of steps before final decay (default: 0.85)
final_div_factor: Final value divisor (default: 1e4)
Returns:
Schedule function
"""def piecewise_constant_schedule(boundaries_and_scales):
"""
Piecewise constant schedule with different values in different intervals.
Args:
boundaries_and_scales: Dict mapping step boundaries to scale factors
Returns:
Schedule function
"""def piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales):
"""
Piecewise schedule with interpolation between boundaries.
Args:
interpolate_type: Type of interpolation ('linear', 'cosine')
init_value: Initial value
boundaries_and_scales: Dict mapping boundaries to scale factors
Returns:
Schedule function
"""def warmup_constant_schedule(init_value, peak_value, warmup_steps):
"""
Linear warmup followed by constant value.
Args:
init_value: Initial value during warmup
peak_value: Constant value after warmup
warmup_steps: Number of warmup steps
Returns:
Schedule function
"""def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0):
"""
Linear warmup followed by cosine decay.
Args:
init_value: Initial value during warmup
peak_value: Peak value after warmup
warmup_steps: Number of warmup steps
decay_steps: Number of decay steps after warmup
end_value: Final value after decay (default: 0.0)
Returns:
Schedule function
"""def warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
"""
Linear warmup followed by exponential decay.
Args:
init_value: Initial value during warmup
peak_value: Peak value after warmup
warmup_steps: Number of warmup steps
transition_steps: Steps between decay applications
decay_rate: Exponential decay rate
transition_begin: Step to begin decay (default: 0)
staircase: Whether to apply decay in discrete steps (default: False)
end_value: Minimum decay value (default: None)
Returns:
Schedule function
"""def linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, final_div_factor=1e4):
"""
One-cycle linear schedule (warmup, decay, final decay).
Args:
transition_steps: Total number of steps
peak_value: Maximum value at peak
pct_start: Percentage of steps for warmup phase (default: 0.3)
pct_final: Percentage of steps before final decay (default: 0.85)
final_div_factor: Final value divisor (default: 1e4)
Returns:
Schedule function
"""def sgdr_schedule(cosine_decay_schedule, restart_period, t_mult=1.0):
"""
Stochastic Gradient Descent with Restarts (SGDR) schedule.
Args:
cosine_decay_schedule: Base cosine decay schedule
restart_period: Initial restart period
t_mult: Multiplier for restart period (default: 1.0)
Returns:
Schedule function
"""def join_schedules(schedules, boundaries):
"""
Join multiple schedules at specified boundaries.
Args:
schedules: List of schedule functions
boundaries: List of step boundaries for schedule transitions
Returns:
Combined schedule function
"""def inject_hyperparams(transformation, **scheduled_hyperparams):
"""
Inject scheduled hyperparameters into transformation.
Args:
transformation: Base gradient transformation
**scheduled_hyperparams: Named schedule functions for hyperparameters
Returns:
GradientTransformation with scheduled hyperparameters
"""def inject_stateful_hyperparams(transformation, **scheduled_hyperparams):
"""
Inject stateful scheduled hyperparameters into transformation.
Args:
transformation: Base gradient transformation
**scheduled_hyperparams: Named stateful schedule functions
Returns:
GradientTransformation with stateful scheduled hyperparameters
"""class InjectHyperparamsState:
"""State for hyperparameter injection."""
count: int
inner_state: OptState
class InjectStatefulHyperparamsState:
"""State for stateful hyperparameter injection."""
count: int
inner_state: OptState
hyperparams_states: dict
class WrappedSchedule:
"""Wrapper for schedule functions with state."""
schedule_fn: Schedule

import optax
# Create different schedules
constant_lr = optax.constant_schedule(0.001)
linear_decay = optax.linear_schedule(0.001, 0.0001, 1000)
cosine_decay = optax.cosine_decay_schedule(0.001, 1000)
exponential_decay = optax.exponential_decay(0.001, 0.96, 100)
# Use schedule with optimizer
optimizer = optax.adam(learning_rate=cosine_decay)
# Evaluate schedule at different steps
step_0_lr = constant_lr(0) # 0.001
step_500_lr = linear_decay(500) # 0.00055 (linear interpolation from 0.001 to 0.0001 at the halfway point)
step_1000_lr = cosine_decay(1000) # close to 0

# Warmup followed by cosine decay
warmup_cosine = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=0.001,
warmup_steps=1000,
decay_steps=9000,
end_value=0.00001
)
# Warmup followed by constant
warmup_constant = optax.warmup_constant_schedule(
init_value=0.0,
peak_value=0.001,
warmup_steps=500
)
# Use with optimizer
optimizer = optax.adamw(learning_rate=warmup_cosine, weight_decay=0.01)

# Different learning rates at different training phases
boundaries_and_scales = {
500: 1.0, # LR = init_value * 1.0 until step 500
1000: 0.5, # LR = init_value * 0.5 from step 500-1000
1500: 0.1 # LR = init_value * 0.1 from step 1000-1500
}
piecewise_sched = optax.piecewise_constant_schedule(boundaries_and_scales)
# With interpolation
piecewise_interp = optax.piecewise_interpolate_schedule(
'linear', 0.001, boundaries_and_scales
)

# One-cycle schedule
onecycle = optax.cosine_onecycle_schedule(
transition_steps=5000,
peak_value=0.01,
pct_start=0.3, # 30% warmup
pct_final=0.85 # 85% before final decay
)
# SGDR with restarts
base_cosine = optax.cosine_decay_schedule(0.001, 1000)
sgdr = optax.sgdr_schedule(base_cosine, restart_period=1000, t_mult=2.0)
# Join multiple schedules
schedules = [
optax.constant_schedule(0.001), # First 1000 steps
optax.linear_schedule(0.001, 0.0001, 1000) # Next 1000 steps
]
joined = optax.join_schedules(schedules, [1000])

# Schedule multiple hyperparameters
base_transform = optax.scale_by_adam()
scheduled_transform = optax.inject_hyperparams(
base_transform,
learning_rate=optax.cosine_decay_schedule(0.001, 1000),
b1=optax.linear_schedule(0.9, 0.95, 500),
b2=optax.constant_schedule(0.999)
)
# Create complete optimizer
optimizer = optax.chain(
scheduled_transform,
optax.scale(-1.0) # Apply negative learning rate
)

import jax
# Create schedule
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=0.001,
warmup_steps=1000,
decay_steps=9000
)
optimizer = optax.adam(learning_rate=schedule)
def train_step(params, opt_state, batch, step):
"""Training step with scheduled learning rate."""
def loss_fn(p):
return compute_loss(p, batch)
loss_val, grads = jax.value_and_grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
# Current learning rate for logging
current_lr = schedule(step)
return params, opt_state, loss_val, current_lr

Install with Tessl CLI
npx tessl i tessl/pypi-optax