or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

compute-resources.mdcontainer-images.mdcore-application-client.mdfunction-decorators-helpers.mdindex.mdinfrastructure-services.mdruntime-utilities.mdscheduling-reliability.mdstorage-data.mdutility-classes.mdweb-api-integration.md

function-decorators-helpers.mddocs/

# Function Decorators & Helpers

Modal provides specialized decorators and helper functions for enhancing function behavior, defining class lifecycle methods, enabling batched execution, and controlling concurrency. These tools allow fine-grained control over how functions execute in the Modal environment.

## Capabilities

### Method Decorator

Decorator for defining methods within Modal classes, enabling stateful serverless computing with shared instance state.

```python { .api }
def method(func: Callable) -> Callable:
    """Decorator to define methods within Modal classes"""
```

#### Usage Examples

```python
import modal

app = modal.App()

@app.cls()
class DataProcessor:
    def __init__(self, model_path: str):
        # Constructor runs during instance creation
        self.model = load_model(model_path)
        self.cache = {}

    @modal.method()
    def process_single(self, data: str) -> str:
        # Method can access instance state
        if data in self.cache:
            return self.cache[data]

        result = self.model.predict(data)
        self.cache[data] = result
        return result

    @modal.method()
    def process_batch(self, data_list: list[str]) -> list[str]:
        # Another method sharing the same instance state
        return [self.process_single(data) for data in data_list]

    @modal.method()
    def get_cache_size(self) -> int:
        return len(self.cache)

# Usage
@app.local_entrypoint()
def main():
    processor = DataProcessor("path/to/model")

    # Call methods on the remote instance
    result1 = processor.process_single.remote("input1")
    result2 = processor.process_batch.remote(["input2", "input3"])
    cache_size = processor.get_cache_size.remote()

    print(f"Results: {result1}, {result2}")
    print(f"Cache size: {cache_size}")
```

### Parameter Helper

Helper function for defining class initialization parameters with validation and default values, similar to dataclass fields.

```python { .api }
def parameter(*, default: Any = _no_default, init: bool = True) -> Any:
    """Define class initialization parameters with options"""
```

#### Usage Examples

```python
import modal

app = modal.App()

@app.cls()
class ConfigurableService:
    # Parameters with type annotations and defaults
    model_name: str = modal.parameter()
    batch_size: int = modal.parameter(default=32)
    temperature: float = modal.parameter(default=0.7)
    debug_mode: bool = modal.parameter(default=False)

    # Internal field not used in constructor
    _internal_cache: dict = modal.parameter(init=False)

    def __post_init__(self):
        # Initialize internal state after parameter injection
        self._internal_cache = {}
        print(f"Service initialized with model={self.model_name}, batch_size={self.batch_size}")

    @modal.method()
    def configure_service(self):
        # Use parameters in methods
        if self.debug_mode:
            print(f"Debug: Processing with temperature={self.temperature}")

        return {
            "model": self.model_name,
            "batch_size": self.batch_size,
            "temperature": self.temperature
        }

# Usage with different configurations
@app.local_entrypoint()
def main():
    # Create instances with different parameters
    service1 = ConfigurableService("gpt-4", batch_size=64, debug_mode=True)
    service2 = ConfigurableService("claude-3", temperature=0.5)

    config1 = service1.configure_service.remote()
    config2 = service2.configure_service.remote()

    print("Service 1 config:", config1)
    print("Service 2 config:", config2)
```

### Lifecycle Decorators

Decorators for defining class lifecycle methods that run during container startup and shutdown.

```python { .api }
def enter(func: Callable) -> Callable:
    """Decorator for class enter lifecycle method (runs on container startup)"""

def exit(func: Callable) -> Callable:
    """Decorator for class exit lifecycle method (runs on container shutdown)"""
```

#### Usage Examples

```python
import modal

app = modal.App()

@app.cls()
class DatabaseService:
    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self.connection = None
        self.cache = None

    @modal.enter()
    def setup_connections(self):
        """Run once when container starts"""
        print("Setting up database connection...")
        self.connection = create_database_connection(self.connection_string)
        self.cache = initialize_cache()
        print("Database service ready!")

    @modal.exit()
    def cleanup_connections(self):
        """Run once when container shuts down"""
        print("Cleaning up database connections...")
        if self.connection:
            self.connection.close()
        if self.cache:
            self.cache.clear()
        print("Cleanup complete!")

    @modal.method()
    def query_data(self, sql: str) -> list[dict]:
        # Connection is already established from enter()
        cursor = self.connection.cursor()
        cursor.execute(sql)
        return cursor.fetchall()

    @modal.method()
    def cached_query(self, sql: str) -> list[dict]:
        # Use cache initialized in enter()
        if sql in self.cache:
            return self.cache[sql]

        result = self.query_data(sql)
        self.cache[sql] = result
        return result

# Usage
@app.local_entrypoint()
def main():
    db_service = DatabaseService("postgresql://user:pass@host:5432/db")

    # First call triggers enter() lifecycle
    results = db_service.query_data.remote("SELECT * FROM users LIMIT 10")

    # Subsequent calls reuse the established connection
    cached_results = db_service.cached_query.remote("SELECT COUNT(*) FROM users")

    print("Query results:", results)
    print("Cached results:", cached_results)

    # Container shutdown triggers exit() lifecycle automatically
```

### Execution Control Decorators

Decorators for controlling how functions execute, including batching and concurrency patterns.

```python { .api }
def batched(max_batch_size: int = 10) -> Callable:
    """Decorator to enable batched function calls for improved throughput"""

def concurrent(func: Callable) -> Callable:
    """Decorator to enable concurrent function execution"""
```

#### Usage Examples

```python
import modal

app = modal.App()

@app.function()
@modal.batched(max_batch_size=50)
def process_items_batched(items: list[str]) -> list[str]:
    """Process multiple items in a single function call"""
    print(f"Processing batch of {len(items)} items")

    # Expensive setup that benefits from batching
    model = load_expensive_model()

    # Process all items in the batch
    results = []
    for item in items:
        result = model.process(item)
        results.append(result)

    return results

@app.function()
@modal.concurrent
def process_item_concurrent(item: str) -> str:
    """Process items with concurrent execution"""
    # Each call can run concurrently with others
    return expensive_processing(item)

@app.local_entrypoint()
def main():
    # Batched processing - items are automatically grouped
    items = [f"item_{i}" for i in range(100)]

    # These calls will be automatically batched up to max_batch_size
    batch_results = []
    for item in items:
        result = process_items_batched.remote([item])  # Each call adds to batch
        batch_results.append(result)

    print(f"Batched processing completed: {len(batch_results)} results")

    # Concurrent processing - items run in parallel
    concurrent_futures = []
    for item in items[:10]:  # Process first 10 concurrently
        future = process_item_concurrent.spawn(item)
        concurrent_futures.append(future)

    # Collect concurrent results
    concurrent_results = [future.get() for future in concurrent_futures]
    print(f"Concurrent processing completed: {len(concurrent_results)} results")
```

## Advanced Patterns

### Stateful Service with Lifecycle Management

```python
import modal

app = modal.App()

@app.cls()
class MLInferenceService:
    model_name: str = modal.parameter()
    cache_size: int = modal.parameter(default=1000)

    @modal.enter()
    def load_model(self):
        """Load model and initialize cache on container start"""
        print(f"Loading model: {self.model_name}")
        self.model = download_and_load_model(self.model_name)
        self.prediction_cache = LRUCache(maxsize=self.cache_size)
        self.stats = {"requests": 0, "cache_hits": 0}
        print("Model loaded and ready for inference")

    @modal.exit()
    def save_stats(self):
        """Save statistics before container shutdown"""
        print(f"Final stats: {self.stats}")
        save_stats_to_database(self.stats)

    @modal.method()
    @modal.batched(max_batch_size=32)
    def predict_batch(self, inputs: list[str]) -> list[dict]:
        """Batched prediction with caching"""
        results = []
        uncached_inputs = []
        uncached_indices = []

        # Check cache for each input
        for i, inp in enumerate(inputs):
            if inp in self.prediction_cache:
                results.append(self.prediction_cache[inp])
                self.stats["cache_hits"] += 1
            else:
                results.append(None)  # Placeholder
                uncached_inputs.append(inp)
                uncached_indices.append(i)

        # Batch process uncached inputs
        if uncached_inputs:
            batch_predictions = self.model.predict(uncached_inputs)
            for idx, prediction in zip(uncached_indices, batch_predictions):
                self.prediction_cache[inputs[idx]] = prediction
                results[idx] = prediction

        self.stats["requests"] += len(inputs)
        return results

    @modal.method()
    def get_stats(self) -> dict:
        """Get current service statistics"""
        return self.stats.copy()

# Usage
@app.local_entrypoint()
def main():
    # Create service instance
    ml_service = MLInferenceService(model_name="bert-base-uncased", cache_size=500)

    # Make predictions (automatically batched)
    test_inputs = [f"test sentence {i}" for i in range(100)]
    predictions = ml_service.predict_batch.remote(test_inputs)

    # Check service statistics
    stats = ml_service.get_stats.remote()
    print(f"Service stats: {stats}")

    # Make some repeated predictions to test caching
    repeat_predictions = ml_service.predict_batch.remote(test_inputs[:10])
    final_stats = ml_service.get_stats.remote()
    print(f"Final stats with cache hits: {final_stats}")
```

### Concurrent Task Processing with Shared State

```python
import modal

app = modal.App()

@app.cls()
class TaskProcessor:
    max_workers: int = modal.parameter(default=10)

    @modal.enter()
    def setup_processor(self):
        """Initialize shared resources"""
        self.task_queue = initialize_task_queue()
        self.result_store = initialize_result_store()
        self.worker_stats = {}

    @modal.method()
    @modal.concurrent
    def process_task_concurrent(self, task_id: str, worker_id: str) -> dict:
        """Process individual tasks concurrently"""
        # Track worker statistics
        if worker_id not in self.worker_stats:
            self.worker_stats[worker_id] = {"processed": 0, "errors": 0}

        try:
            # Process the task
            task_data = self.task_queue.get_task(task_id)
            result = expensive_task_processing(task_data)

            # Store result
            self.result_store.put(task_id, result)
            self.worker_stats[worker_id]["processed"] += 1

            return {"status": "success", "task_id": task_id, "worker": worker_id}

        except Exception as e:
            self.worker_stats[worker_id]["errors"] += 1
            return {"status": "error", "task_id": task_id, "error": str(e)}

    @modal.method()
    def get_worker_stats(self) -> dict:
        """Get statistics for all workers"""
        return self.worker_stats.copy()

@app.local_entrypoint()
def main():
    processor = TaskProcessor(max_workers=20)

    # Process many tasks concurrently
    task_ids = [f"task_{i}" for i in range(100)]
    futures = []

    for i, task_id in enumerate(task_ids):
        worker_id = f"worker_{i % 20}"  # Distribute across workers
        future = processor.process_task_concurrent.spawn(task_id, worker_id)
        futures.append(future)

    # Collect results
    results = [future.get() for future in futures]

    # Check worker statistics
    stats = processor.get_worker_stats.remote()
    print(f"Worker statistics: {stats}")

    # Analyze results
    successful = sum(1 for r in results if r["status"] == "success")
    errors = sum(1 for r in results if r["status"] == "error")
    print(f"Processed {successful} tasks successfully, {errors} errors")
```