# Middleware System

Starlette provides a powerful middleware system for processing requests and responses in a composable pipeline. Middleware can handle cross-cutting concerns like CORS, compression, authentication, error handling, and custom request/response processing.

## Middleware Base Classes

### Middleware Configuration

```python { .api }
9
from starlette.middleware import Middleware
10
from typing import Any, Sequence
11
12
from collections.abc import Iterator


class Middleware:
    """
    Middleware configuration wrapper.

    Encapsulates a middleware class together with the arguments used to
    construct it, for use in an application's middleware stack.
    Instances unpack as ``(cls, args, kwargs)``, so a stack builder can write
    ``for cls, args, kwargs in middleware: ...``.
    """

    def __init__(self, cls: type, *args: Any, **kwargs: Any) -> None:
        """
        Initialize middleware configuration.

        Args:
            cls: Middleware class (or factory)
            *args: Positional arguments for middleware constructor
            **kwargs: Keyword arguments for middleware constructor
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        """Unpack as ``(cls, args, kwargs)``."""
        return iter((self.cls, self.args, self.kwargs))

    def __repr__(self) -> str:
        """Render as ``Middleware(ClassName, arg, key=value)``."""
        parts = [getattr(self.cls, "__name__", repr(self.cls))]
        parts.extend(repr(value) for value in self.args)
        parts.extend(f"{key}={value!r}" for key, value in self.kwargs.items())
        return f"{type(self).__name__}({', '.join(parts)})"
```

### Base HTTP Middleware

```python { .api }
37
from starlette.middleware.base import BaseHTTPMiddleware
38
from starlette.requests import Request
39
from starlette.responses import Response
40
from starlette.types import ASGIApp, Scope, Receive, Send
41
from typing import Callable, Awaitable
42
43
class BaseHTTPMiddleware:
    """
    Base class for HTTP middleware.

    Provides a simplified interface for creating middleware that processes
    HTTP requests and responses. Connections whose scope type is not
    ``"http"`` are passed straight through to the wrapped app.

    Customise behaviour either by overriding :meth:`dispatch` in a subclass
    or by passing a ``dispatch`` callable to the constructor.
    """

    def __init__(self, app: ASGIApp, dispatch: Callable | None = None) -> None:
        """
        Initialize HTTP middleware.

        Args:
            app: Next ASGI application in the stack
            dispatch: Optional custom dispatch function with the same
                signature as :meth:`dispatch`. When omitted, the (possibly
                overridden) ``dispatch`` method is used.
        """
        self.app = app
        # Always define dispatch_func and always call it from __call__, so a
        # dispatch callable passed here is actually honoured.
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI application interface."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            """Run the rest of the stack and return its response."""
            # Implementation details are elided in this reference.
            ...

        request = Request(scope, receive)
        response = await self.dispatch_func(request, call_next)
        await response(scope, receive, send)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process HTTP request and response.

        Override this method to implement middleware logic.

        Args:
            request: HTTP request object
            call_next: Function to call next middleware/endpoint

        Returns:
            Response: HTTP response object
        """
        # Default implementation just calls next middleware
        response = await call_next(request)
        return response
```

## Built-in Middleware

### CORS Middleware

```python { .api }
105
from starlette.middleware.cors import CORSMiddleware
106
from typing import Sequence
107
108
class CORSMiddleware:
    """
    Middleware implementing Cross-Origin Resource Sharing (CORS).

    Responsibilities:
    - answering preflight requests
    - checking the request origin against the configured allow-list
    - optionally permitting credentialed requests
    - advertising which methods and headers are permitted or exposed
    """

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: Sequence[str] = (),
        allow_methods: Sequence[str] = ("GET",),
        allow_headers: Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """
        Configure the CORS policy.

        Args:
            app: Wrapped ASGI application
            allow_origins: Origins allowed to make cross-origin requests
                (``"*"`` permits any origin)
            allow_methods: HTTP methods permitted cross-origin
            allow_headers: Request headers permitted cross-origin
            allow_credentials: Whether cookies / auth headers are allowed
            allow_origin_regex: Regular expression matched against the origin
            expose_headers: Response headers made readable by the browser
            max_age: Seconds a browser may cache a preflight result
        """
        ...

    def is_allowed_origin(self, origin: str) -> bool:
        """Return whether *origin* passes the configured origin rules."""
        ...

    def preflight_response(self, request_headers: Headers) -> Response:
        """Build the reply to a CORS preflight request."""
        ...

    async def simple_response(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request_headers: Headers,
    ) -> None:
        """Serve a simple (non-preflight) cross-origin request."""
        ...
# CORS constants

# Standard HTTP methods, in alphabetical order.
ALL_METHODS = (
    "DELETE",
    "GET",
    "HEAD",
    "OPTIONS",
    "PATCH",
    "POST",
    "PUT",
)

# CORS-safelisted request header names, stored lower-case.
SAFELISTED_HEADERS = {
    "accept",
    "accept-language",
    "content-language",
    "content-type",
}
```

### GZip Compression Middleware

```python { .api }
172
from starlette.middleware.gzip import GZipMiddleware
173
174
class GZipMiddleware:
    """
    Middleware that gzip-compresses response bodies.

    A response is compressed only when all of these hold:
    - the client advertises gzip support
    - the body is at least the configured minimum size
    - the response content type is one worth compressing
    """

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int = 500,
        compresslevel: int = 9,
    ) -> None:
        """
        Configure gzip compression.

        Args:
            app: Wrapped ASGI application
            minimum_size: Smallest body, in bytes, that will be compressed
            compresslevel: gzip level from 1 (fastest) to 9 (smallest output)
        """
        ...
```

### Session Middleware

```python { .api }
203
from starlette.middleware.sessions import SessionMiddleware
204
205
class SessionMiddleware:
    """
    Cookie-based session middleware.

    Session data is stored in a signed cookie and exposed as the
    ``request.session`` dictionary. The cookie is signed, not encrypted:
    the client can read its contents but cannot modify them without
    invalidating the signature. Do not put secrets in the session.
    """

    def __init__(
        self,
        app: ASGIApp,
        secret_key: str,
        session_cookie: str = "session",
        max_age: int | None = 14 * 24 * 60 * 60,  # 14 days
        path: str = "/",
        same_site: str = "lax",
        https_only: bool = False,
        domain: str | None = None,
    ) -> None:
        """
        Initialize session middleware.

        Args:
            app: ASGI application
            secret_key: Key used to sign the cookie (keep secret; load it from
                configuration rather than hard-coding it)
            session_cookie: Cookie name for session data
            max_age: Session lifetime in seconds; ``None`` makes the cookie
                last only for the browser session
            path: Cookie path
            same_site: SameSite cookie policy (``"lax"``, ``"strict"`` or ``"none"``)
            https_only: Set the cookie's ``Secure`` flag so it is only sent over HTTPS
            domain: Cookie domain
        """
```

### Authentication Middleware

```python { .api }
242
from starlette.middleware.authentication import AuthenticationMiddleware
243
from starlette.authentication import AuthenticationBackend
244
from starlette.requests import Request
245
from starlette.responses import Response
246
247
class AuthenticationMiddleware:
    """
    Middleware that authenticates each request through a pluggable backend.

    The configured backend decides who the caller is; the outcome is exposed
    to handlers as ``request.user`` and ``request.auth``.
    """

    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: Callable[[Request, Exception], Response] | None = None,
    ) -> None:
        """
        Configure authentication.

        Args:
            app: Wrapped ASGI application
            backend: Object implementing the authentication strategy
            on_error: Optional handler that builds the response when
                authentication fails
        """
        ...

    def default_on_error(self, conn: Request, exc: Exception) -> Response:
        """Fallback response builder for authentication failures."""
        ...
```

### Error Handling Middleware

```python { .api }
277
from starlette.middleware.errors import ServerErrorMiddleware
278
from starlette.requests import Request
279
from starlette.responses import Response, PlainTextResponse, HTMLResponse
280
281
class ServerErrorMiddleware:
    """
    Outermost safety net for unhandled exceptions.

    Any exception that escapes the application is turned into an error
    response. With ``debug`` enabled, that response is a detailed page
    including the stack trace.
    """

    def __init__(
        self,
        app: ASGIApp,
        handler: Callable[[Request, Exception], Response] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Configure server-error handling.

        Args:
            app: Wrapped ASGI application
            handler: Optional callable that builds the error response
            debug: When true, render detailed error pages
        """
        ...

    def debug_response(self, request: Request, exc: Exception) -> Response:
        """Build the detailed, traceback-bearing error response."""
        ...

    def error_response(self, request: Request, exc: Exception) -> Response:
        """Build the plain error response used outside debug mode."""
        ...
```

### Exception Middleware

```python { .api }
314
from starlette.middleware.exceptions import ExceptionMiddleware
315
from starlette.exceptions import HTTPException
316
317
class ExceptionMiddleware:
    """
    Routes raised exceptions to registered handlers.

    Supported out of the box:
    - ``HTTPException`` for HTTP requests
    - ``WebSocketException`` for WebSocket connections
    - user-registered handlers keyed by exception class
    - user-registered handlers keyed by HTTP status code
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: dict[Any, Callable] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Configure exception handling.

        Args:
            app: Wrapped ASGI application
            handlers: Mapping of exception class or status code to handler
            debug: Enable debug mode
        """
        ...

    def add_exception_handler(
        self,
        exc_class_or_status_code: type[Exception] | int,
        handler: Callable,
    ) -> None:
        """Register *handler* for an exception class or an HTTP status code."""
        ...

    async def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Turn an ``HTTPException`` into a response."""
        ...

    async def websocket_exception(self, websocket: WebSocket, exc: WebSocketException) -> None:
        """React to a ``WebSocketException`` on a WebSocket connection."""
        ...
```

### Security Middleware

```python { .api }
360
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
361
from starlette.middleware.trustedhost import TrustedHostMiddleware
362
363
class HTTPSRedirectMiddleware:
    """
    Middleware that sends plain-HTTP requests to their HTTPS equivalent.

    Clients are redirected instead of being served over HTTP.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap *app* with HTTPS redirection."""
        ...
class TrustedHostMiddleware:
    """
    Trusted host validation middleware.

    Validates the Host header against allowed hosts to prevent
    Host header injection attacks. Requests for hosts that are not allowed
    are rejected with a 400 response.
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: Sequence[str] | None = None,
        www_redirect: bool = True,
    ) -> None:
        """
        Initialize trusted host middleware.

        Args:
            app: ASGI application
            allowed_hosts: Allowed hostnames. ``"*"`` allows any host and a
                leading wildcard such as ``"*.example.com"`` allows subdomains.
                ``None`` (the default) is treated as ``["*"]``.
            www_redirect: If the requested host is not allowed but its
                ``www.``-prefixed form is (e.g. ``example.com`` requested while
                only ``www.example.com`` is allowed), redirect to the ``www.``
                host instead of rejecting the request.
        """
```

## Middleware Configuration

### Application-level Middleware

```python { .api }
402
from starlette.applications import Starlette
403
from starlette.middleware import Middleware
404
from starlette.middleware.cors import CORSMiddleware
405
from starlette.middleware.gzip import GZipMiddleware
406
from starlette.middleware.sessions import SessionMiddleware
407
408
import os

from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

# Configure middleware stack.
# Order matters: the first entry is the outermost layer, so it sees the
# request first and the response last.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["https://example.com"],
        allow_methods=["GET", "POST"],
        allow_headers=["*"],
    ),
    Middleware(GZipMiddleware, minimum_size=1000),
    # Never hard-code the signing key: anyone who knows it can forge sessions.
    # Load it from the environment (or a secret store) instead.
    Middleware(SessionMiddleware, secret_key=os.environ["SESSION_SECRET_KEY"]),
]

app = Starlette(
    routes=routes,  # `routes` is assumed to be defined elsewhere
    middleware=middleware,
)

# add_middleware() inserts at the front of the stack, so each call wraps
# everything registered before it (TrustedHostMiddleware ends up outermost).
# Recent Starlette versions raise if this is called after the app has started.
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)
```

### Route-specific Middleware

```python { .api }
433
from starlette.routing import Route
434
from starlette.middleware import Middleware
435
436
import time

from starlette.middleware.base import BaseHTTPMiddleware


# Middleware for specific routes
class TimingMiddleware(BaseHTTPMiddleware):
    """Add an ``X-Process-Time`` header (elapsed seconds) to each response."""

    async def dispatch(self, request, call_next):
        # perf_counter is monotonic and high-resolution, so the measurement
        # cannot go negative or jump if the wall clock is adjusted mid-request.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response


# `homepage`, `api_endpoint`, `DatabaseMiddleware` and `CacheMiddleware` are
# assumed to be defined elsewhere (the latter two in "Middleware Dependencies").
routes = [
    Route("/", homepage),  # No extra middleware
    Route(
        "/api/data",
        api_endpoint,
        middleware=[
            Middleware(TimingMiddleware),
            # CacheMiddleware takes no options and needs DatabaseMiddleware
            # wrapped around it to provide request.state.db.
            Middleware(DatabaseMiddleware),
            Middleware(CacheMiddleware),
        ],
    ),
]
```

## Custom Middleware

### Simple Function Middleware

```python { .api }
459
from starlette.types import ASGIApp, Scope, Receive, Send
460
461
def logging_middleware(app: ASGIApp) -> ASGIApp:
    """Wrap *app* so the method and path of every HTTP request are printed.

    Every scope, HTTP or not, is then forwarded unchanged to *app*.
    """

    async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
        is_http = scope["type"] == "http"
        if is_http:
            method, path = scope["method"], scope["path"]
            print(f"Request: {method} {path}")
        await app(scope, receive, send)

    return middleware
# Usage
from starlette.applications import Starlette

app = Starlette(routes=routes)  # `routes` is assumed to be defined elsewhere
# The middleware stack calls logging_middleware(app) to build the wrapper.
app.add_middleware(logging_middleware)
```

### Class-based Middleware

```python { .api }
480
import time
481
from starlette.middleware.base import BaseHTTPMiddleware
482
from starlette.requests import Request
483
from starlette.responses import Response
484
485
class TimingMiddleware(BaseHTTPMiddleware):
    """Middleware to add timing headers to responses.

    Sets ``X-Process-Time`` to the seconds spent producing the response.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted mid-request.
        start_time = time.perf_counter()

        # Process request
        response = await call_next(request)

        # Add timing information
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)

        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-client-IP, sliding-window rate limiting middleware.

    Each client may make ``calls_per_minute`` requests in any 60-second
    window. Further requests get a 429 response. State lives in this process
    only.
    """

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client IP -> time.monotonic() timestamps of accepted calls inside
        # the current window, oldest first.
        self.clients: dict[str, list[float]] = {}  # In production, use Redis or similar

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import JSONResponse

        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock adjustments must not reset or extend windows.
        now = time.monotonic()
        window_start = now - 60

        # Forget clients with no call inside the window so the dict stays bounded.
        # Timestamps are appended in increasing order, so the last one is newest.
        self.clients = {
            ip: calls
            for ip, calls in self.clients.items()
            if calls and calls[-1] > window_start
        }

        # Keep only this client's calls inside the window; storing the pruned
        # list back stops an active client's history from growing forever.
        recent_calls = [t for t in self.clients.get(client_ip, []) if t > window_start]
        self.clients[client_ip] = recent_calls

        if len(recent_calls) >= self.calls_per_minute:
            return JSONResponse(
                {"error": "Rate limit exceeded"},
                status_code=429,
                headers={"Retry-After": "60"},
            )

        # Record this call, computing the remaining quota before awaiting so
        # concurrent requests cannot skew it.
        recent_calls.append(now)
        remaining = self.calls_per_minute - len(recent_calls)

        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)

        return response
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Custom API-key authentication middleware.

    Requests to ``PUBLIC_PATHS`` pass through untouched. Every other request
    must carry an ``X-API-Key`` header that ``validate_api_key`` resolves to
    a user. That user is stored on ``request.state.user``.
    """

    # Exact paths that skip authentication; override in a subclass to change.
    PUBLIC_PATHS = frozenset({"/", "/health", "/docs"})

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import JSONResponse

        # Skip authentication for public endpoints
        if request.url.path in self.PUBLIC_PATHS:
            return await call_next(request)

        # A 401 response must carry a WWW-Authenticate challenge (RFC 9110),
        # both when the key is missing and when it is invalid.
        challenge = {"WWW-Authenticate": "ApiKey"}

        # Check for API key
        api_key = request.headers.get("X-API-Key")
        if not api_key:
            return JSONResponse(
                {"error": "API key required"},
                status_code=401,
                headers=challenge,
            )

        # Validate API key
        user = await validate_api_key(api_key)
        if not user:
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401,
                headers=challenge,
            )

        # Add user to request state
        request.state.user = user

        return await call_next(request)
async def validate_api_key(api_key: str):
    """Resolve *api_key* to a user object, or ``None`` when it is not valid.

    Placeholder: replace the body with real validation logic. As written it
    returns ``None`` for every key.
    """
    # Implement your API key validation logic here.
    return None
```

### ASGI Middleware

```python { .api }
582
import uuid


class RequestIDMiddleware:
    """Add unique request ID to each request.

    For HTTP scopes, a random UUID4 string is stored in
    ``scope["request_id"]`` and echoed back as an ``x-request-id`` response
    header. Other scope types pass through unchanged.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Generate unique request ID and expose it to downstream code
        request_id = str(uuid.uuid4())
        scope["request_id"] = request_id

        # Wrap send to add the header to the response start message
        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append([b"x-request-id", request_id.encode()])
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)
import zlib

try:  # Brotli is an optional third-party package; "br" is skipped without it.
    import brotli
except ImportError:
    brotli = None


class CompressionMiddleware:
    """Custom compression middleware with multiple algorithms.

    Behaviour:
    - The first entry of ``algorithms`` that the client accepts (per its
      ``Accept-Encoding`` header) is used.
    - Supported entries are ``"gzip"`` and, when the ``brotli`` package is
      installed, ``"br"``; other entries are ignored.
    - The body is compressed as one stream across all
      ``http.response.body`` messages.
    - ``Content-Length`` is dropped, since the compressed size is not known
      in advance, and ``Vary: Accept-Encoding`` is added.
    - These are passed through untouched: HEAD requests, responses that
      already carry ``Content-Encoding``, and responses whose status never
      has a body.
    """

    # Statuses whose responses must not carry a body, so must not be compressed.
    _NO_BODY_STATUSES = frozenset({204, 304})

    def __init__(self, app: ASGIApp, algorithms: list[str] | None = None):
        self.app = app
        requested = algorithms or ["gzip", "br"]
        # Keep only encodings we can actually produce, so we never advertise
        # an encoding and then send an uncompressed body.
        self.algorithms = [
            name
            for name in requested
            if name == "gzip" or (name == "br" and brotli is not None)
        ]

    def _select_encoding(self, accept_encoding: str) -> str | None:
        """Return the first configured algorithm the client accepts, or None.

        Honours ``q=0`` (an explicit refusal) and the ``*`` wildcard.
        """
        accepted: set[str] = set()
        refused: set[str] = set()
        for item in accept_encoding.split(","):
            coding, _, params = item.partition(";")
            coding = coding.strip().lower()
            if not coding:
                continue
            quality = 1.0
            for param in params.split(";"):
                key, _, value = param.partition("=")
                if key.strip().lower() == "q":
                    try:
                        quality = float(value)
                    except ValueError:
                        quality = 0.0
            (accepted if quality > 0 else refused).add(coding)

        for algorithm in self.algorithms:
            if algorithm in refused:
                continue
            if algorithm in accepted or "*" in accepted:
                return algorithm
        return None

    @staticmethod
    def _new_compressor(encoding: str):
        """Return ``(compress, sync_flush, finish)`` callables for a fresh stream."""
        if encoding == "gzip":
            # wbits = 16 + MAX_WBITS makes zlib emit a gzip header and trailer;
            # level 9 matches gzip.compress()'s default.
            stream = zlib.compressobj(9, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
            return stream.compress, lambda: stream.flush(zlib.Z_SYNC_FLUSH), stream.flush
        stream = brotli.Compressor()
        return stream.process, stream.flush, stream.finish

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http" or scope.get("method") == "HEAD":
            await self.app(scope, receive, send)
            return

        # Check if client accepts compression (latin-1 never fails to decode).
        request_headers = dict(scope.get("headers", []))
        accept_encoding = request_headers.get(b"accept-encoding", b"").decode("latin-1")
        selected_encoding = self._select_encoding(accept_encoding)

        if selected_encoding is None:
            await self.app(scope, receive, send)
            return

        # Set once compression is switched on for this response.
        compressor = None

        async def send_wrapper(message):
            nonlocal compressor

            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                header_names = {name.lower() for name, _ in headers}
                status = message["status"]
                if (
                    b"content-encoding" not in header_names
                    and status >= 200
                    and status not in self._NO_BODY_STATUSES
                ):
                    # The original Content-Length describes the uncompressed body.
                    headers = [
                        (name, value)
                        for name, value in headers
                        if name.lower() != b"content-length"
                    ]
                    headers.append((b"content-encoding", selected_encoding.encode("ascii")))
                    headers.append((b"vary", b"Accept-Encoding"))
                    message = {**message, "headers": headers}
                    compressor = self._new_compressor(selected_encoding)

            elif message["type"] == "http.response.body" and compressor is not None:
                compress, sync_flush, finish = compressor
                chunk = message.get("body", b"")
                if message.get("more_body", False):
                    # Sync-flush each non-empty chunk so streamed data reaches
                    # the client promptly; skip empty chunks to avoid flush noise.
                    data = compress(chunk) + sync_flush() if chunk else b""
                else:
                    data = compress(chunk) + finish()
                message = {**message, "body": data}

            await send(message)

        await self.app(scope, receive, send_wrapper)
```

## Middleware Best Practices

### Error Handling in Middleware

```python { .api }
668
class SafeMiddleware(BaseHTTPMiddleware):
    """Middleware with proper error handling.

    An exception raised by ``before_request``, by the downstream app (surfaced
    through ``call_next``) or by ``after_request`` is converted into the
    response returned by ``handle_error``.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        try:
            # Pre-processing
            await self.before_request(request)

            # Call next middleware/endpoint
            response = await call_next(request)

            # Post-processing
            await self.after_request(request, response)
        except Exception as e:
            # Handle middleware errors
            return await self.handle_error(request, e)

        return response

    async def before_request(self, request: Request):
        """Override for pre-processing logic."""
        pass

    async def after_request(self, request: Request, response: Response):
        """Override for post-processing logic."""
        pass

    async def handle_error(self, request: Request, exc: Exception) -> Response:
        """Override for error handling.

        The default logs the exception with its traceback and returns a
        generic JSON 500 response, so no internal details reach the client.
        """
        import logging

        from starlette.responses import JSONResponse

        # Log through `logging` with exc_info so the traceback is kept;
        # print() dropped it and bypassed the application's log configuration.
        logging.getLogger(__name__).error("Middleware error: %s", exc, exc_info=exc)
        return JSONResponse(
            {"error": "Internal server error"},
            status_code=500,
        )
```

### Conditional Middleware

```python { .api }
708
class ConditionalMiddleware(BaseHTTPMiddleware):
    """Run :meth:`process_request` only for requests matching a predicate.

    Requests for which ``condition_func(request)`` is falsy go straight to the
    next layer.
    """

    def __init__(self, app, condition_func: Callable[[Request], bool]):
        super().__init__(app)
        self.should_apply = condition_func

    async def dispatch(self, request: Request, call_next) -> Response:
        if self.should_apply(request):
            # Predicate matched: hand over to the customisable hook.
            return await self.process_request(request, call_next)
        return await call_next(request)

    async def process_request(self, request: Request, call_next) -> Response:
        """Hook for subclasses; the default simply forwards the request."""
        return await call_next(request)
# Usage examples
def is_api_request(request: Request) -> bool:
    """Return True when the request path lies under ``/api/``."""
    path = request.url.path
    return path.startswith("/api/")
def is_authenticated(request: Request) -> bool:
    """Return True when the request carries an ``authorization`` header.

    Only the header's presence is checked; its value is not validated.
    """
    headers = request.headers
    return "authorization" in headers
# Apply middleware conditionally: ConditionalMiddleware acts only on /api/ paths.
app.add_middleware(ConditionalMiddleware, condition_func=is_api_request)
```

### Middleware Dependencies

```python { .api }
744
class DatabaseMiddleware(BaseHTTPMiddleware):
    """Open a database connection per request and expose it as ``request.state.db``."""

    async def dispatch(self, request: Request, call_next) -> Response:
        # `database` is assumed to be defined elsewhere; its connection() is
        # used as an async context manager.
        async with database.connection() as connection:
            request.state.db = connection
            # NOTE(review): the connection is released as soon as call_next
            # returns; confirm no endpoint keeps using it while streaming a body.
            return await call_next(request)
class CacheMiddleware(BaseHTTPMiddleware):
    """Cache responses (requires database middleware).

    Bodies of successful GET responses are stored with
    ``request.state.db.set_cache(key, body_bytes)``. On a cache hit the
    stored bytes are returned as an ``application/json`` response.
    """

    # Only safe, idempotent reads may be answered from cache.
    CACHEABLE_METHODS = frozenset({"GET"})

    async def dispatch(self, request: Request, call_next) -> Response:
        from starlette.responses import Response

        # This middleware depends on DatabaseMiddleware
        if not hasattr(request.state, "db"):
            raise RuntimeError("CacheMiddleware requires DatabaseMiddleware")

        if request.method not in self.CACHEABLE_METHODS:
            return await call_next(request)

        cache_key = f"response:{request.url}"

        # Try cache first. The cache holds raw body bytes, so return them as-is;
        # JSONResponse would try to re-serialise the bytes and fail.
        cached = await request.state.db.get_cache(cache_key)
        if cached:
            return Response(cached, media_type="application/json")

        # Generate response
        response = await call_next(request)
        if response.status_code != 200:
            return response

        # Responses from call_next are streaming and have no `.body`: drain the
        # iterator, cache the bytes, and rebuild an equivalent response.
        body = b"".join([chunk async for chunk in response.body_iterator])
        await request.state.db.set_cache(cache_key, body)
        return Response(
            content=body,
            status_code=response.status_code,
            headers=response.headers,
        )
# Order matters: DatabaseMiddleware is listed first, making it the outer layer,
# so request.state.db is already set by the time CacheMiddleware runs.
middleware = [Middleware(DatabaseMiddleware), Middleware(CacheMiddleware)]
```

Starlette's middleware system provides a powerful way to implement cross-cutting concerns with a clean, composable architecture that supports both simple and complex use cases.