# Middleware System

Middleware components for request/response processing, including authentication, CORS, compression, rate limiting, and custom middleware. Litestar supports both function-based and class-based middleware, composed into a flexible processing pipeline.

## Capabilities

### Middleware Base Classes

Base classes for implementing custom middleware components.

```python { .api }
class AbstractMiddleware:
    """Base class for custom middleware; by default it forwards every call unchanged."""

    def __init__(self, app: ASGIApp):
        """
        Remember the ASGI application this middleware wraps.

        Parameters:
        - app: Next ASGI application in the middleware chain
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle one ASGI call by delegating to the wrapped application.

        Subclasses override this to inspect or alter the scope, the incoming
        messages, or the outgoing messages.

        Parameters:
        - scope: ASGI connection scope
        - receive: Awaitable callable returning the next incoming ASGI message
        - send: Awaitable callable that emits an outgoing ASGI message
        """
        wrapped = self.app
        await wrapped(scope, receive, send)

class ASGIMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        middleware: type[AbstractMiddleware],
        **kwargs: Any,
    ):
        """
        Adapt a middleware class so it can sit in the ASGI pipeline.

        Parameters:
        - app: ASGI application being wrapped
        - middleware: Middleware class that will be instantiated
        - **kwargs: Keyword arguments handed to the middleware constructor
        """
```

### Middleware Configuration

Configuration classes for defining and organizing middleware.

```python { .api }
class DefineMiddleware:
    """Pairs a middleware class or factory with the arguments used to build it."""

    def __init__(
        self,
        middleware: type[AbstractMiddleware] | Callable[..., ASGIApp],
        *args: Any,
        **kwargs: Any,
    ):
        """
        Bind a middleware class or factory to its constructor arguments.

        Parameters:
        - middleware: Middleware class or factory function
        - *args: Positional arguments for the middleware
        - **kwargs: Keyword arguments for the middleware
        """

    @property
    def middleware(self) -> type[AbstractMiddleware] | Callable[..., ASGIApp]:
        """The registered middleware class or factory."""

    @property
    def args(self) -> tuple[Any, ...]:
        """The stored positional arguments."""

    @property
    def kwargs(self) -> dict[str, Any]:
        """The stored keyword arguments."""

def middleware(
    middleware_class: type[AbstractMiddleware],
    *args: Any,
    **kwargs: Any,
) -> Callable[[T], T]:
    """
    Build a decorator that attaches middleware to a route handler.

    Parameters:
    - middleware_class: Middleware class to attach
    - *args: Positional arguments for the middleware
    - **kwargs: Keyword arguments for the middleware

    Returns:
    A decorator that accepts a handler and returns an object of the same type
    """
```

### Built-in Middleware

Pre-built middleware components for common use cases.

```python { .api }
class CORSMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        *,
        allow_origins: Sequence[str] = ("*",),
        allow_methods: Sequence[str] = ("GET",),
        allow_headers: Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: Sequence[str] = (),
        max_age: int = 600,
    ):
        """
        Cross-Origin Resource Sharing (CORS) middleware.

        Parameters:
        - app: ASGI application to wrap
        - allow_origins: Origin patterns that may make cross-origin requests
        - allow_methods: HTTP methods permitted cross-origin
        - allow_headers: Request headers permitted cross-origin
        - allow_credentials: Whether credentialed requests are allowed
        - allow_origin_regex: Regular expression matching additional allowed origins
        - expose_headers: Response headers the client may read
        - max_age: How long, in seconds, a preflight response may be cached
        """

class CompressionMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        *,
        backend: Literal["gzip", "brotli"] = "gzip",
        minimum_size: int = 500,
        exclude_opt_key: str = "skip_compression",
        compression_facade: CompressionFacade | None = None,
    ):
        """
        Compress response bodies with the selected algorithm.

        Parameters:
        - app: ASGI application to wrap
        - backend: Compression algorithm, "gzip" or "brotli"
        - minimum_size: Responses below this size are not compressed
        - exclude_opt_key: Route option key that opts a route out of compression
        - compression_facade: Optional custom compression implementation
        """

class SessionMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        store: Store,
        *,
        secret: str | Secret,
        key: str = "session",
        max_age: int = 1209600,  # 14 days
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = True,
        samesite: Literal["lax", "strict", "none"] = "lax",
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_session",
    ):
        """
        Manage per-client sessions backed by a store.

        Parameters:
        - app: ASGI application to wrap
        - store: Backend that persists session data
        - secret: Key used to sign sessions
        - key: Name of the session cookie
        - max_age: Session lifetime in seconds (default 1209600, i.e. 14 days)
        - path: Cookie Path attribute
        - domain: Cookie Domain attribute
        - secure: Whether the cookie is marked Secure
        - httponly: Whether the cookie is marked HttpOnly
        - samesite: Cookie SameSite attribute
        - exclude: Path(s) that bypass session handling
        - exclude_opt_key: Route option key that opts a route out
        """

class TrustedHostMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: Sequence[str],
        *,
        exclude_opt_key: str = "skip_allowed_hosts_check",
        exclude: str | list[str] | None = None,
        redirect_domains: Sequence[str] | None = None,
    ):
        """
        Reject requests whose Host is not on the allow-list.

        Parameters:
        - app: ASGI application to wrap
        - allowed_hosts: Host patterns that are accepted
        - exclude_opt_key: Route option key that opts a route out
        - exclude: Path(s) that skip the host check
        - redirect_domains: Domains to redirect to when the host is not allowed
        """

class CSRFMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        *,
        secret: str | Secret,
        cookie_name: str = "csrftoken",
        cookie_path: str = "/",
        cookie_domain: str | None = None,
        cookie_secure: bool = False,
        cookie_httponly: bool = False,
        cookie_samesite: Literal["lax", "strict", "none"] = "lax",
        header_name: str = "x-csrftoken",
        safe_methods: set[str] | None = None,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_csrf_protection",
    ):
        """
        Cross-Site Request Forgery (CSRF) protection.

        Parameters:
        - app: ASGI application to wrap
        - secret: Key used to generate tokens
        - cookie_name: Name of the CSRF token cookie
        - cookie_path: Cookie Path attribute
        - cookie_domain: Cookie Domain attribute
        - cookie_secure: Whether the cookie is marked Secure
        - cookie_httponly: Whether the cookie is marked HttpOnly
        - cookie_samesite: Cookie SameSite attribute
        - header_name: Request header that carries the CSRF token
        - safe_methods: HTTP methods exempt from CSRF checks
        - exclude: Path(s) that skip CSRF protection
        - exclude_opt_key: Route option key that opts a route out
        """
```

### Exception Handling Middleware

Middleware for handling exceptions and generating appropriate error responses.

```python { .api }
class ExceptionHandlerMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        *,
        debug: bool = False,
        exception_handlers: dict[type[Exception], ExceptionHandler] | None = None,
    ):
        """
        Convert exceptions raised downstream into error responses.

        Parameters:
        - app: ASGI application to wrap
        - debug: Include detailed error information in responses
        - exception_handlers: Mapping of exception type to custom handler
        """

    async def handle_exception(
        self,
        request: Request,
        exc: Exception,
    ) -> Response:
        """
        Build the response for an exception.

        Parameters:
        - request: HTTP request during which the exception was raised
        - exc: The raised exception

        Returns:
        Response describing the exception
        """
```

### Rate Limiting Middleware

Middleware for implementing rate limiting and throttling.

```python { .api }
class RateLimitMiddleware(AbstractMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        store: Store,
        *,
        rate_limit: RateLimit | None = None,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_rate_limiting",
    ):
        """
        Throttle clients according to a rate limit.

        Parameters:
        - app: ASGI application to wrap
        - store: Backend that stores rate-limit counters
        - rate_limit: Default limit configuration
        - exclude: Path(s) that are never rate limited
        - exclude_opt_key: Route option key that opts a route out
        """

class RateLimit:
    def __init__(
        self,
        rate: int,
        per: timedelta,
        *,
        identifier: Callable[[Request], str] | None = None,
        key_formatter: Callable[[str, str], str] | None = None,
    ):
        """
        Describe how many requests a client may make in a time window.

        Parameters:
        - rate: Requests allowed per window
        - per: Length of the window
        - identifier: Callable mapping a request to a client identifier
        - key_formatter: Callable that builds the storage key
        """
```

### Custom Middleware Patterns

Common patterns for implementing custom middleware.

```python { .api }
class ProcessingTimeMiddleware(AbstractMiddleware):
    """
    Example middleware that adds an ``x-process-time`` response header.

    The value is the time in seconds from the start of the call until the
    response start message (status and headers) is sent.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Non-HTTP scopes (websocket, lifespan) are passed through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # perf_counter() is monotonic, unlike time.time(), so a wall-clock
        # adjustment cannot yield a negative or skewed duration.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                process_time = time.perf_counter() - start_time
                # ASGI allows headers to be any iterable of pairs (e.g. a tuple),
                # so build a new list instead of appending in place.
                message["headers"] = [
                    *message.get("headers", []),
                    (b"x-process-time", str(process_time).encode()),
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)

class LoggingMiddleware(AbstractMiddleware):
    """Example middleware for request/response logging."""

    def __init__(self, app: ASGIApp, logger: Logger | None = None):
        """
        Parameters:
        - app: ASGI application to wrap
        - logger: Logger to write to; falls back to ``logging.getLogger(__name__)``
        """
        super().__init__(app)
        self.logger = logger or logging.getLogger(__name__)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)
        # Lazy %-style arguments: the message is only formatted when INFO is enabled.
        self.logger.info("Request: %s %s", request.method, request.url)

        # Monotonic clock for measuring elapsed time.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                self.logger.info("Response: %s (%.3fs)", message["status"], duration)
            await send(message)

        await self.app(scope, receive, send_wrapper)
```

### Middleware Protocol

Structural protocol (a `typing.Protocol`) that middleware implementations satisfy.

```python { .api }
class MiddlewareProtocol(Protocol):
    """Structural type describing any object usable as middleware."""

    def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
        """Accept the wrapped ASGI app plus middleware-specific options."""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI call using the standard (scope, receive, send) signature."""
```

## Usage Examples

### Basic Middleware Usage

```python
from litestar import Litestar, get
from litestar.config.compression import CompressionConfig
from litestar.config.cors import CORSConfig


@get("/api/data")
def get_data() -> dict[str, str]:
    return {"message": "Hello World"}


# Built-in CORS and compression are enabled through dedicated config objects
# passed to the application, not by listing the middleware classes directly.
app = Litestar(
    route_handlers=[get_data],
    cors_config=CORSConfig(
        allow_origins=["https://frontend.example.com"],
        allow_methods=["GET", "POST"],
        allow_headers=["content-type"],
    ),
    compression_config=CompressionConfig(
        backend="gzip",
        minimum_size=500,
    ),
)
```

### Custom Middleware Implementation

```python
import time

from litestar import Litestar, get
from litestar.middleware.base import AbstractMiddleware
from litestar.types import Message, Receive, Scope, Send


class RequestTimingMiddleware(AbstractMiddleware):
    """Add an ``x-response-time`` header (seconds, 3 decimals) to HTTP responses."""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                # ASGI permits any iterable of header pairs, so rebuild the list
                # rather than appending to a possibly immutable sequence.
                message["headers"] = [
                    *message.get("headers", []),
                    (b"x-response-time", f"{duration:.3f}".encode()),
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)


@get("/api/data")
def get_data() -> dict[str, str]:
    return {"message": "Hello World"}


app = Litestar(
    route_handlers=[get_data],
    middleware=[RequestTimingMiddleware],
)
```

### Route-Specific Middleware

```python
from litestar import get
from litestar.connection import ASGIConnection
from litestar.exceptions import NotAuthorizedException
from litestar.middleware import DefineMiddleware
from litestar.middleware.authentication import (
    AbstractAuthenticationMiddleware,
    AuthenticationResult,
)
from litestar.types import ASGIApp


class APIKeyMiddleware(AbstractAuthenticationMiddleware):
    """Authenticate requests by checking the ``X-API-Key`` header against known keys."""

    def __init__(self, app: ASGIApp, api_keys: set[str]):
        super().__init__(app)
        self.api_keys = api_keys

    async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
        api_key = connection.headers.get("X-API-Key")

        # Reject missing or unknown keys by raising. Returning an empty result
        # would not stop the request, and AuthenticationResult requires both
        # ``user`` and ``auth`` anyway.
        if api_key is None or api_key not in self.api_keys:
            raise NotAuthorizedException("Invalid or missing API key")

        return AuthenticationResult(user={"api_key": api_key}, auth=api_key)


@get(
    "/admin/users",
    # In production, load the keys from configuration rather than source code.
    middleware=[DefineMiddleware(APIKeyMiddleware, api_keys={"secret-key-123"})],
)
def admin_users() -> list[dict]:
    return [{"id": 1, "name": "admin"}]
```

### Session Middleware with Custom Store

```python
import os

from litestar import Litestar, Request, get, post
from litestar.exceptions import NotAuthorizedException
from litestar.middleware import DefineMiddleware
from litestar.middleware.session import SessionMiddleware
from litestar.stores.redis import RedisStore

# Redis session store (RedisStore wraps a client; with_client builds one from a URL)
redis_store = RedisStore.with_client(url="redis://localhost:6379")

# SessionMiddleware takes the wrapped ``app`` as its first argument, which only
# the application can supply, so register it via DefineMiddleware instead of
# instantiating it directly.
session_middleware = DefineMiddleware(
    SessionMiddleware,
    store=redis_store,
    # Never hard-code the signing secret; read it from the environment.
    secret=os.environ["SESSION_SECRET"],
    key="sessionid",
    max_age=3600,  # 1 hour
    httponly=True,
    secure=True,
    samesite="strict",
)


@get("/profile")
def get_profile(request: Request) -> dict:
    user_id = request.session.get("user_id")
    if user_id is None:
        raise NotAuthorizedException("Login required")

    return {"user_id": user_id, "session_data": dict(request.session)}


@post("/login")
def login(request: Request, data: dict) -> dict:
    # Authenticate user (simplified: a real login must verify credentials)
    if data.get("username") == "alice":
        request.session["user_id"] = 123
        request.session["username"] = "alice"
        return {"status": "logged in"}

    raise NotAuthorizedException("Invalid credentials")


app = Litestar(
    route_handlers=[get_profile, login],
    middleware=[session_middleware],
)
```

### CORS Configuration

```python
from litestar import Litestar
from litestar.config.cors import CORSConfig

cors_config = CORSConfig(
    allow_origins=["https://app.example.com", "https://admin.example.com"],
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
    allow_credentials=True,
    expose_headers=["X-Total-Count"],
    max_age=86400,  # 24 hours
)

app = Litestar(
    route_handlers=[...],  # your route handlers
    cors_config=cors_config,
)
```

### Middleware with Dependency Injection

```python
from typing import Any

from litestar import get
from litestar.di import Provide
from litestar.middleware.base import AbstractMiddleware
from litestar.types import ASGIApp, Receive, Scope, Send


class DatabaseMiddleware(AbstractMiddleware):
    """Acquire a pooled connection for each HTTP request and expose it on the scope."""

    def __init__(self, app: ASGIApp, db_pool: Any):
        super().__init__(app)
        # NOTE: assumes a pool whose acquire() is an async context manager
        # (asyncpg-style) — adapt to your driver.
        self.db_pool = db_pool

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # The connection is released when the ``async with`` block exits, i.e.
        # after the downstream app has finished handling the request.
        async with self.db_pool.acquire() as conn:
            scope["db_connection"] = conn
            await self.app(scope, receive, send)


def get_db_connection(scope: Scope) -> Any:
    # ``scope`` is a reserved dependency name, injected by Litestar.
    return scope.get("db_connection")


# Providers are declared with ``Provide`` (``Dependency`` is a parameter marker,
# not a provider). The provider does no blocking work, so it need not run in a
# thread.
@get("/users", dependencies={"db": Provide(get_db_connection, sync_to_thread=False)})
async def get_users(db: Any) -> list[dict]:
    # Use database connection from middleware
    result = await db.fetch("SELECT * FROM users")
    return [dict(row) for row in result]
```

### Custom Exception Handling Middleware

```python
import json
import logging
import uuid

from litestar import Litestar
from litestar.exceptions import HTTPException
from litestar.middleware.base import AbstractMiddleware
from litestar.types import Message, Receive, Scope, Send

logger = logging.getLogger(__name__)


class CustomExceptionMiddleware(AbstractMiddleware):
    """Turn unhandled exceptions into a JSON 500 response that carries an error id."""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP connections can be answered with an HTTP error response.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message: Message) -> None:
            # Track whether the status line and headers have already been sent.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except HTTPException:
            # Let HTTP exceptions pass through
            raise
        except Exception:
            # Create the id before logging so the log entry and the response
            # can be correlated.
            error_id = str(uuid.uuid4())
            logger.exception("Unhandled exception occurred (error_id=%s)", error_id)

            if response_started:
                # A response is already in flight; sending a second
                # http.response.start would break the ASGI protocol.
                raise

            body = json.dumps(
                {
                    "status_code": 500,
                    "detail": "Internal server error",
                    "extra": {"error_id": error_id},
                }
            ).encode()
            await send(
                {
                    "type": "http.response.start",
                    "status": 500,
                    "headers": [
                        (b"content-type", b"application/json"),
                        (b"content-length", str(len(body)).encode()),
                    ],
                }
            )
            await send({"type": "http.response.body", "body": body})


app = Litestar(
    route_handlers=[...],  # your route handlers
    middleware=[CustomExceptionMiddleware],
    debug=False,
)
```

### Conditional Middleware

```python
import time
from collections.abc import Callable

from litestar import Litestar, Request
from litestar.middleware import DefineMiddleware
from litestar.middleware.base import AbstractMiddleware
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class ConditionalMiddleware(AbstractMiddleware):
    """Add an ``x-conditional-timing`` header only to requests matching ``condition``."""

    def __init__(self, app: ASGIApp, condition: Callable[[Request], bool]):
        super().__init__(app)
        self.condition = condition

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)
        if not self.condition(request):
            await self.app(scope, receive, send)
            return

        # Apply middleware logic; a monotonic clock is used for elapsed time.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                # Rebuild the header list: ASGI allows immutable header iterables.
                message["headers"] = [
                    *message.get("headers", []),
                    (b"x-conditional-timing", f"{duration:.3f}".encode()),
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)


# Only apply timing to API routes
def is_api_route(request: Request) -> bool:
    return request.url.path.startswith("/api/")


app = Litestar(
    route_handlers=[...],  # your route handlers
    middleware=[DefineMiddleware(ConditionalMiddleware, is_api_route)],
)
```

## Types

```python { .api }
# ASGI types
# Ordered so each alias references only names defined above it; defining
# ASGIApp before Scope/Receive/Send raises NameError when executed.
Message = dict[str, Any]
Scope = dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]

# Middleware types
MiddlewareType = type[AbstractMiddleware] | Callable[..., ASGIApp]
Middleware = MiddlewareType | DefineMiddleware

# Exception handler type: may return a Response directly or an awaitable of one
ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]]


# Store interface for session/rate limiting
class Store(Protocol):
    async def get(self, key: str) -> Any | None: ...

    # expires_in: optional expiry; presumably seconds — confirm against the store
    async def set(self, key: str, value: Any, expires_in: int | None = None) -> None: ...

    async def delete(self, key: str) -> None: ...


# Rate limiting types
RateLimitIdentifier = Callable[[Request], str]
RateLimitKeyFormatter = Callable[[str, str], str]
```