# Function Decorators & Helpers

Modal provides specialized decorators and helper functions for defining class methods and lifecycle hooks, declaring class parameters, batching calls, and controlling concurrency. These tools give fine-grained control over how functions execute in the Modal environment.

## Capabilities

### Method Decorator

Decorator for defining methods within Modal classes. Each container creates one instance of the class and serves method calls on it, so methods share that instance's state. Separate containers do not share state with each other.

```python { .api }
def method(*, is_generator: Optional[bool] = None) -> Callable[[Callable], Callable]:
    """Return a decorator that exposes a method of a Modal class for remote calls.

    Always used with parentheses: ``@modal.method()``. Decorated methods are
    invoked through an instance with ``.remote()``, ``.map()`` or ``.spawn()``.
    """
```

#### Usage Examples

```python
import modal

app = modal.App()


@app.cls()
class DataProcessor:
    # Class parameters replace a custom __init__; callers pass them by keyword.
    model_path: str = modal.parameter()

    @modal.enter()
    def setup(self):
        # Runs once per container, before it serves any method calls.
        self.model = load_model(self.model_path)  # placeholder for your own loader
        self.cache = {}

    def _predict_cached(self, data: str) -> str:
        """Return the prediction for `data`, computing and caching it on a miss."""
        # A plain (undecorated) helper, so the Modal methods below can share this
        # logic without calling each other's decorated entry points.
        if data not in self.cache:
            self.cache[data] = self.model.predict(data)
        return self.cache[data]

    @modal.method()
    def process_single(self, data: str) -> str:
        # Methods can read and update instance state created in setup()
        return self._predict_cached(data)

    @modal.method()
    def process_batch(self, data_list: list[str]) -> list[str]:
        # Shares the same per-container cache as process_single
        return [self._predict_cached(data) for data in data_list]

    @modal.method()
    def get_cache_size(self) -> int:
        # Reports the cache of whichever container serves this call
        return len(self.cache)


# Usage
@app.local_entrypoint()
def main():
    processor = DataProcessor(model_path="path/to/model")

    # Call methods on the remote instance
    result1 = processor.process_single.remote("input1")
    result2 = processor.process_batch.remote(["input2", "input3"])
    cache_size = processor.get_cache_size.remote()

    print(f"Results: {result1}, {result2}")
    print(f"Cache size: {cache_size}")
```

### Parameter Helper

Helper function for declaring class parameters with type validation and optional default values, similar to dataclass fields. Parameters are keyword-only, so pass them by name when instantiating the class.

```python { .api }
def parameter(*, default: Any = _no_default, init: bool = True) -> Any:
    """Declare a Modal class parameter, similar to ``dataclasses.field``.

    Parameters become keyword-only arguments of the class constructor. Omit
    ``default`` to make a parameter required. ``init=False`` annotates a field
    that is not a constructor parameter, e.g. state set in an
    ``@modal.enter()`` method.
    """
```

#### Usage Examples

```python
import modal

app = modal.App()


@app.cls()
class ConfigurableService:
    # Parameters: a type annotation plus modal.parameter(); omit `default` to
    # make a parameter required. Use simple types such as str, int, and bool.
    model_name: str = modal.parameter()
    batch_size: int = modal.parameter(default=32)
    max_tokens: int = modal.parameter(default=256)
    debug_mode: bool = modal.parameter(default=False)

    # init=False: an annotated field that is NOT a constructor parameter;
    # its value is assigned in setup()
    _internal_cache: dict = modal.parameter(init=False)

    @modal.enter()
    def setup(self):
        # Initialize state that depends on parameters in an @modal.enter()
        # method (not __post_init__); it runs once per container.
        self._internal_cache = {}
        print(f"Service initialized with model={self.model_name}, batch_size={self.batch_size}")

    @modal.method()
    def configure_service(self):
        # Parameters are available as instance attributes in every method
        if self.debug_mode:
            print(f"Debug: Processing with max_tokens={self.max_tokens}")

        return {
            "model": self.model_name,
            "batch_size": self.batch_size,
            "max_tokens": self.max_tokens,
        }


# Usage with different configurations
@app.local_entrypoint()
def main():
    # Parameters are keyword-only, so pass each one by name
    service1 = ConfigurableService(model_name="gpt-4", batch_size=64, debug_mode=True)
    service2 = ConfigurableService(model_name="claude-3", max_tokens=512)

    config1 = service1.configure_service.remote()
    config2 = service2.configure_service.remote()

    print("Service 1 config:", config1)
    print("Service 2 config:", config2)
```
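
### Lifecycle Decorators

Decorators for class lifecycle methods. An `@modal.enter()` method runs once when a container starts, before it serves any inputs. An `@modal.exit()` method runs once when the container shuts down. Use `@modal.enter()` for expensive setup such as loading models or opening connections, instead of a custom `__init__`.
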
```python { .api }
def enter(*, snap: bool = False) -> Callable[[Callable], Callable]:
    """Return a decorator marking a class method to run once when a container starts.

    Always used with parentheses: ``@modal.enter()``. ``snap=True`` is used
    together with memory snapshots.
    """


def exit() -> Callable[[Callable], Callable]:
    """Return a decorator marking a class method to run once when a container shuts down.

    Always used with parentheses: ``@modal.exit()``.
    """
```

#### Usage Examples

```python
import modal

app = modal.App()


@app.cls()
class DatabaseService:
    # Passed by keyword when the class is instantiated
    connection_string: str = modal.parameter()

    @modal.enter()
    def setup_connections(self):
        """Run once when container starts"""
        # Define the attributes first so cleanup_connections() can check them
        # even if a later setup step fails.
        self.connection = None
        self.cache = None
        print("Setting up database connection...")
        self.connection = create_database_connection(self.connection_string)
        self.cache = initialize_cache()
        print("Database service ready!")

    @modal.exit()
    def cleanup_connections(self):
        """Run once when container shuts down"""
        print("Cleaning up database connections...")
        if self.connection is not None:
            self.connection.close()
        if self.cache is not None:
            self.cache.clear()
        print("Cleanup complete!")

    def _run_query(self, sql: str) -> list[dict]:
        """Execute `sql` on this container's connection and return all rows."""
        # `sql` is executed as-is: only pass trusted statements, never text
        # built from user input. The cursor is closed even if execution fails.
        cursor = self.connection.cursor()
        try:
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            cursor.close()

    @modal.method()
    def query_data(self, sql: str) -> list[dict]:
        # Connection is already established by setup_connections()
        return self._run_query(sql)

    @modal.method()
    def cached_query(self, sql: str) -> list[dict]:
        # The cache lives in this container only; other containers keep their own
        if sql in self.cache:
            return self.cache[sql]

        result = self._run_query(sql)
        self.cache[sql] = result
        return result


# Usage
@app.local_entrypoint()
def main():
    # Placeholder DSN; in real code load credentials from a modal.Secret
    db_service = DatabaseService(connection_string="postgresql://user:pass@host:5432/db")

    # Starting a container runs setup_connections() before it serves this call
    results = db_service.query_data.remote("SELECT * FROM users LIMIT 10")

    # Calls served by a warm container reuse its established connection
    cached_results = db_service.cached_query.remote("SELECT COUNT(*) FROM users")

    print("Query results:", results)
    print("Cached results:", cached_results)

    # When a container shuts down, Modal runs cleanup_connections() automatically
```
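
### Execution Control Decorators

Decorators for controlling how functions execute. `@modal.batched` collects individual calls into lists, so the function processes many inputs at once. `@modal.concurrent` lets a single container handle several inputs at the same time.
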
```python { .api }
def batched(*, max_batch_size: int, wait_ms: int) -> Callable[[Callable], Callable]:
    """Return a decorator that groups individual calls into batches.

    Callers still pass one input per call. The decorated function receives a
    list of inputs and must return a list with one output per input, in the
    same order. A batch holds at most ``max_batch_size`` inputs, and Modal
    waits at most ``wait_ms`` milliseconds for a batch to fill.
    """


def concurrent(*, max_inputs: int, target_inputs: Optional[int] = None) -> Callable[[Callable], Callable]:
    """Return a decorator that lets one container handle up to ``max_inputs`` inputs at once.

    Apply it below ``@app.function()`` on a function, or below ``@app.cls()``
    on a whole class (not on individual methods). ``target_inputs`` is an
    optional autoscaling target lower than ``max_inputs``.
    """
```

#### Usage Examples

```python
import modal

app = modal.App()


@app.function()
@modal.batched(max_batch_size=50, wait_ms=1000)
def process_items_batched(items: list[str]) -> list[str]:
    """Process a batch of inputs that Modal gathered from individual calls"""
    print(f"Processing batch of {len(items)} items")

    # Expensive setup that benefits from batching
    model = load_expensive_model()

    # Must return exactly one result per input, in the same order
    return [model.process(item) for item in items]


@app.function()
@modal.concurrent(max_inputs=10)
def process_item_concurrent(item: str) -> str:
    """Process one item; a container may run up to 10 of these at once"""
    return expensive_processing(item)


@app.local_entrypoint()
def main():
    items = [f"item_{i}" for i in range(100)]

    # Call a batched function with ONE input per call; Modal groups pending
    # calls into batches of up to max_batch_size. .map() submits every call up
    # front so batches can fill; sequential .remote() calls would each block.
    batch_results = list(process_items_batched.map(items))

    print(f"Batched processing completed: {len(batch_results)} results")

    # Spawn calls without waiting, so they can run concurrently
    concurrent_futures = [process_item_concurrent.spawn(item) for item in items[:10]]

    # Collect concurrent results
    concurrent_results = [future.get() for future in concurrent_futures]
    print(f"Concurrent processing completed: {len(concurrent_results)} results")
```

## Advanced Patterns

### Stateful Service with Lifecycle Management

```python
import modal

app = modal.App()


@app.cls()
class MLInferenceService:
    model_name: str = modal.parameter()
    cache_size: int = modal.parameter(default=1000)

    @modal.enter()
    def load_model(self):
        """Load model and initialize cache on container start"""
        print(f"Loading model: {self.model_name}")
        self.model = download_and_load_model(self.model_name)
        self.prediction_cache = LRUCache(maxsize=self.cache_size)
        # Counters are per container; each container tracks its own numbers
        self.stats = {"requests": 0, "cache_hits": 0}
        print("Model loaded and ready for inference")

    @modal.exit()
    def save_stats(self):
        """Save statistics before container shutdown"""
        # Statistics are reported here: a class with a @modal.batched method
        # cannot also define @modal.method() methods such as a get_stats().
        print(f"Final stats: {self.stats}")
        save_stats_to_database(self.stats)

    # On a class, @modal.batched is used on its own, without @modal.method()
    @modal.batched(max_batch_size=32, wait_ms=100)
    def predict_batch(self, inputs: list[str]) -> list[dict]:
        """Batched prediction with caching.

        Modal passes the inputs of up to 32 individual calls as one list and
        expects one result per input, in the same order.
        """
        results: list = [None] * len(inputs)  # placeholders, filled below
        uncached_indices = []

        # Check cache for each input
        for i, inp in enumerate(inputs):
            if inp in self.prediction_cache:
                results[i] = self.prediction_cache[inp]
                self.stats["cache_hits"] += 1
            else:
                uncached_indices.append(i)

        # Run the model once for all inputs that missed the cache
        if uncached_indices:
            batch_predictions = self.model.predict([inputs[i] for i in uncached_indices])
            for idx, prediction in zip(uncached_indices, batch_predictions):
                self.prediction_cache[inputs[idx]] = prediction
                results[idx] = prediction

        self.stats["requests"] += len(inputs)
        return results


# Usage
@app.local_entrypoint()
def main():
    # Create service instance
    ml_service = MLInferenceService(model_name="bert-base-uncased", cache_size=500)

    # Call the batched method with one input per call; .map() submits them all
    # at once so Modal can group them into batches of up to 32
    test_inputs = [f"test sentence {i}" for i in range(100)]
    predictions = list(ml_service.predict_batch.map(test_inputs))
    print(f"Received {len(predictions)} predictions")

    # Repeat some inputs; a container that already saw them answers from its cache
    repeat_predictions = list(ml_service.predict_batch.map(test_inputs[:10]))
    print(f"Received {len(repeat_predictions)} repeated predictions")

    # Per-container statistics are printed and saved by save_stats() on shutdown
```

### Concurrent Task Processing with Shared State

```python
import threading

import modal

app = modal.App()


# @modal.concurrent decorates the whole class (below @app.cls()), not a single
# method. Concurrency is fixed here rather than by a constructor argument.
@app.cls()
@modal.concurrent(max_inputs=20)
class TaskProcessor:
    @modal.enter()
    def setup_processor(self):
        """Initialize resources shared by all inputs this container handles"""
        self.task_queue = initialize_task_queue()
        self.result_store = initialize_result_store()
        self.worker_stats = {}
        # Concurrent inputs of a sync method run in separate threads, so the
        # read-modify-write updates to worker_stats must be serialized.
        self.stats_lock = threading.Lock()

    def _record(self, worker_id: str, outcome: str) -> None:
        """Increment the `outcome` counter ("processed" or "errors") for worker_id."""
        with self.stats_lock:
            counts = self.worker_stats.setdefault(worker_id, {"processed": 0, "errors": 0})
            counts[outcome] += 1

    @modal.method()
    def process_task_concurrent(self, task_id: str, worker_id: str) -> dict:
        """Process one task; other inputs may be handled at the same time"""
        try:
            task_data = self.task_queue.get_task(task_id)
            result = expensive_task_processing(task_data)
            self.result_store.put(task_id, result)
        except Exception as e:
            # Report the failure to the caller instead of failing the call
            self._record(worker_id, "errors")
            return {"status": "error", "task_id": task_id, "error": str(e)}

        self._record(worker_id, "processed")
        return {"status": "success", "task_id": task_id, "worker": worker_id}

    @modal.method()
    def get_worker_stats(self) -> dict:
        """Get statistics recorded by the container that serves this call"""
        with self.stats_lock:
            return {worker: dict(counts) for worker, counts in self.worker_stats.items()}


@app.local_entrypoint()
def main():
    processor = TaskProcessor()

    # Spawn every task without waiting. worker_id is only a bookkeeping label;
    # Modal decides which container (and thread) actually runs each call.
    task_ids = [f"task_{i}" for i in range(100)]
    futures = [
        processor.process_task_concurrent.spawn(task_id, f"worker_{i % 20}")
        for i, task_id in enumerate(task_ids)
    ]

    # Collect results
    results = [future.get() for future in futures]

    # Statistics come from one container; if several containers served the
    # tasks, each one keeps its own counts
    stats = processor.get_worker_stats.remote()
    print(f"Worker statistics: {stats}")

    # Analyze results
    successful = sum(1 for r in results if r["status"] == "success")
    errors = sum(1 for r in results if r["status"] == "error")
    print(f"Processed {successful} tasks successfully, {errors} errors")
```