# Optimization

Warp provides gradient-based optimizers for machine learning workflows. These optimizers work seamlessly with Warp's differentiable kernels and integrate with automatic differentiation systems for training neural networks and optimizing physical simulations.

## Capabilities

### Adam Optimizer

Adaptive learning rate optimizer with momentum and bias correction.

```python { .api }
class Adam:
    """Adam optimizer for gradient-based optimization."""

    def __init__(self,
                 params: list = None,
                 lr: float = 0.001,
                 betas: tuple = (0.9, 0.999),
                 eps: float = 1e-8):
        """
        Initialize Adam optimizer.

        Args:
            params: List of arrays to optimize
            lr: Learning rate
            betas: Coefficients for running averages (beta1, beta2)
            eps: Small constant for numerical stability
        """

    def step(self) -> None:
        """
        Perform single optimization step.
        Updates parameters using accumulated gradients.
        """

    def zero_grad(self) -> None:
        """Clear gradients of all parameters."""

    @property
    def state_dict(self) -> dict:
        """Get optimizer state for checkpointing."""

    def load_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from checkpoint."""
```

### SGD Optimizer

Stochastic gradient descent with optional momentum and weight decay.

```python { .api }
class SGD:
    """Stochastic Gradient Descent optimizer."""

    def __init__(self,
                 params: list = None,
                 lr: float = 0.001,
                 momentum: float = 0.0,
                 dampening: float = 0.0,
                 weight_decay: float = 0.0,
                 nesterov: bool = False):
        """
        Initialize SGD optimizer.

        Args:
            params: List of arrays to optimize
            lr: Learning rate
            momentum: Momentum factor
            dampening: Dampening for momentum
            weight_decay: L2 regularization weight
            nesterov: Enable Nesterov momentum
        """

    def step(self) -> None:
        """Perform single optimization step."""

    def zero_grad(self) -> None:
        """Clear gradients of all parameters."""

    @property
    def state_dict(self) -> dict:
        """Get optimizer state for checkpointing."""

    def load_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from checkpoint."""
```

## Usage Examples

### Basic Neural Network Training
```python
import warp as wp
import numpy as np

# Define neural network parameters
W1 = wp.array(np.random.randn(784, 128).astype(np.float32), device='cuda', requires_grad=True)
b1 = wp.zeros(128, dtype=wp.float32, device='cuda', requires_grad=True)
W2 = wp.array(np.random.randn(128, 10).astype(np.float32), device='cuda', requires_grad=True)
b2 = wp.zeros(10, dtype=wp.float32, device='cuda', requires_grad=True)

# Create optimizer
optimizer = wp.optim.Adam([W1, b1, W2, b2], lr=0.001)

# Training loop
for epoch in range(100):
    for batch_x, batch_y in data_loader:
        # Convert to Warp arrays
        x = wp.from_numpy(batch_x, device='cuda')
        y_true = wp.from_numpy(batch_y, device='cuda')

        # Forward pass using Warp kernels
        h1 = forward_layer(x, W1, b1)
        y_pred = forward_layer(h1, W2, b2)

        # Compute loss
        loss = compute_loss(y_pred, y_true)

        # Backward pass (automatic differentiation)
        wp.backward(loss)

        # Update parameters
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch}, Loss: {loss.numpy()}")
```
126
127
### Physics Simulation Optimization
128
```python
129
import warp as wp
130
131
# Physical parameters to optimize
132
spring_stiffness = wp.array([1000.0], requires_grad=True, device='cuda')
133
damping_coeff = wp.array([0.1], requires_grad=True, device='cuda')
134
135
# Create optimizer
136
optimizer = wp.optim.Adam([spring_stiffness, damping_coeff], lr=0.01)
137
138
# Define physics simulation kernel
139
@wp.kernel
140
def simulate_springs(positions: wp.array(dtype=wp.vec3),
141
velocities: wp.array(dtype=wp.vec3),
142
forces: wp.array(dtype=wp.vec3),
143
stiffness: wp.array(dtype=float),
144
damping: wp.array(dtype=float),
145
dt: float):
146
i = wp.tid()
147
148
pos = positions[i]
149
vel = velocities[i]
150
151
# Spring force (to origin)
152
spring_force = -stiffness[0] * pos
153
154
# Damping force
155
damping_force = -damping[0] * vel
156
157
forces[i] = spring_force + damping_force
158
159
# Target trajectory
160
target_positions = wp.array(target_data, device='cuda')
161
162
# Optimization loop
163
for iteration in range(1000):
164
# Reset simulation state
165
positions = wp.copy(initial_positions)
166
velocities = wp.zeros_like(positions)
167
168
# Run simulation
169
for step in range(simulation_steps):
170
forces = wp.zeros_like(positions)
171
172
wp.launch(simulate_springs,
173
dim=num_particles,
174
inputs=[positions, velocities, forces,
175
spring_stiffness, damping_coeff, dt])
176
177
# Update positions and velocities
178
update_physics(positions, velocities, forces, dt)
179
180
# Compute loss against target
181
loss = wp.mean((positions - target_positions) ** 2)
182
183
# Backward pass
184
wp.backward(loss)
185
186
# Update parameters
187
optimizer.step()
188
optimizer.zero_grad()
189
190
if iteration % 100 == 0:
191
print(f"Iteration {iteration}, Loss: {loss.numpy()}")
192
```
193
194
### Custom Optimization with Multiple Optimizers
195
```python
196
import warp as wp
197
198
# Different parameter groups with different learning rates
199
fast_params = [weight_matrix] # High learning rate
200
slow_params = [bias_vector] # Low learning rate
201
202
# Create separate optimizers
203
fast_optimizer = wp.optim.Adam(fast_params, lr=0.01)
204
slow_optimizer = wp.optim.SGD(slow_params, lr=0.001, momentum=0.9)
205
206
# Training step
207
def training_step(loss):
208
# Compute gradients
209
wp.backward(loss)
210
211
# Update with different schedules
212
fast_optimizer.step()
213
slow_optimizer.step()
214
215
# Clear gradients
216
fast_optimizer.zero_grad()
217
slow_optimizer.zero_grad()
218
219
# Learning rate scheduling
220
def adjust_learning_rate(optimizer, epoch):
221
"""Decay learning rate by factor of 0.1 every 30 epochs."""
222
lr = optimizer.lr * (0.1 ** (epoch // 30))
223
for param_group in optimizer.param_groups:
224
param_group['lr'] = lr
225
```
226
227
### Checkpointing and State Management
228
```python
229
import warp as wp
230
import pickle
231
232
# Create optimizer
233
optimizer = wp.optim.Adam(model_params, lr=0.001)
234
235
# Training with checkpointing
236
for epoch in range(num_epochs):
237
# Training loop
238
for batch in data_loader:
239
loss = compute_loss(batch)
240
wp.backward(loss)
241
optimizer.step()
242
optimizer.zero_grad()
243
244
# Save checkpoint every 10 epochs
245
if epoch % 10 == 0:
246
checkpoint = {
247
'epoch': epoch,
248
'model_state': [param.numpy() for param in model_params],
249
'optimizer_state': optimizer.state_dict,
250
'loss': loss.numpy()
251
}
252
253
with open(f'checkpoint_epoch_{epoch}.pkl', 'wb') as f:
254
pickle.dump(checkpoint, f)
255
256
# Load checkpoint
257
def load_checkpoint(checkpoint_path, model_params, optimizer):
258
with open(checkpoint_path, 'rb') as f:
259
checkpoint = pickle.load(f)
260
261
# Restore model parameters
262
for param, saved_param in zip(model_params, checkpoint['model_state']):
263
param.assign(wp.from_numpy(saved_param, device=param.device))
264
265
# Restore optimizer state
266
optimizer.load_state_dict(checkpoint['optimizer_state'])
267
268
return checkpoint['epoch'], checkpoint['loss']
269
```
270
271
### Gradient Clipping and Regularization
272
```python
273
import warp as wp
274
275
class OptimizerWithClipping:
276
def __init__(self, optimizer, max_grad_norm=1.0):
277
self.optimizer = optimizer
278
self.max_grad_norm = max_grad_norm
279
280
def clip_gradients(self, parameters):
281
"""Clip gradients to prevent exploding gradients."""
282
# Compute total gradient norm
283
total_norm = 0.0
284
for param in parameters:
285
if param.grad is not None:
286
param_norm = wp.norm(param.grad)
287
total_norm += param_norm ** 2
288
289
total_norm = wp.sqrt(total_norm)
290
291
# Scale gradients if norm exceeds threshold
292
if total_norm > self.max_grad_norm:
293
clip_coef = self.max_grad_norm / (total_norm + 1e-6)
294
for param in parameters:
295
if param.grad is not None:
296
param.grad *= clip_coef
297
298
def step(self, parameters):
299
self.clip_gradients(parameters)
300
self.optimizer.step()
301
302
def zero_grad(self):
303
self.optimizer.zero_grad()
304
305
# Usage
306
base_optimizer = wp.optim.Adam(model_params, lr=0.001)
307
optimizer = OptimizerWithClipping(base_optimizer, max_grad_norm=1.0)
308
309
# Training with gradient clipping
310
for batch in data_loader:
311
loss = compute_loss(batch)
312
wp.backward(loss)
313
optimizer.step(model_params)
314
optimizer.zero_grad()
315
```
316
317
### Integration with Automatic Differentiation
318
```python
319
import warp as wp
320
321
# Enable tape for automatic differentiation
322
tape = wp.Tape()
323
324
# Forward pass with tape recording
325
with tape:
326
# Warp kernel computation
327
@wp.kernel
328
def neural_network_kernel(x: wp.array(dtype=float),
329
w: wp.array(dtype=float),
330
y: wp.array(dtype=float)):
331
i = wp.tid()
332
# Simple linear transformation
333
y[i] = w[0] * x[i] + w[1]
334
335
# Launch kernel
336
wp.launch(neural_network_kernel, dim=data_size,
337
inputs=[input_data, weights, output_data])
338
339
# Compute loss
340
loss = wp.mean((output_data - target_data) ** 2)
341
342
# Backward pass
343
tape.backward(loss)
344
345
# Extract gradients
346
weight_gradients = tape.gradients[weights]
347
348
# Manual optimizer step
349
learning_rate = 0.01
350
weights.assign(weights - learning_rate * weight_gradients)
351
352
# Reset tape for next iteration
353
tape.zero()
354
```

## Types

```python { .api }
# Optimizer base interface
class Optimizer:
    """Base class for optimizers."""

    def __init__(self, params: list, lr: float):
        """Initialize optimizer with parameters and learning rate."""

    def step(self) -> None:
        """Perform optimization step."""

    def zero_grad(self) -> None:
        """Clear parameter gradients."""

    @property
    def param_groups(self) -> list:
        """List of parameter groups with optimization settings."""

    @property
    def state_dict(self) -> dict:
        """Optimizer state for checkpointing."""

# Parameter group structure
class ParameterGroup:
    """Group of parameters with shared optimization settings."""

    params: list          # List of parameter arrays
    lr: float             # Learning rate
    weight_decay: float   # L2 regularization

# Optimizer state for individual parameters
class ParameterState:
    """Per-parameter optimization state."""

    step: int               # Number of optimization steps
    exp_avg: array          # Exponential moving average of gradients (Adam)
    exp_avg_sq: array       # Exponential moving average of squared gradients (Adam)
    momentum_buffer: array  # Momentum buffer (SGD)
```