Starlette: the little ASGI library that shines.
Starlette provides a powerful middleware system for processing requests and responses in a composable pipeline. Middleware can handle cross-cutting concerns like CORS, compression, authentication, error handling, and custom request/response processing.
from starlette.middleware import Middleware
from typing import Any, Sequence
class Middleware:
    """
    Middleware configuration wrapper.

    Encapsulates a middleware class (or factory callable) together with the
    arguments it should be constructed with, for use in an application's
    middleware stack configuration. Nothing is constructed here: ``cls`` is
    only stored and called later by whoever builds the stack.

    Instances unpack as ``(cls, args, kwargs)``, so a stack builder can write
    ``for cls, args, kwargs in middleware: ...``.
    """

    def __init__(self, cls: type, *args: Any, **kwargs: Any) -> None:
        """
        Initialize middleware configuration.

        Args:
            cls: Middleware class (or factory) to instantiate later.
            *args: Positional arguments for the middleware constructor.
            **kwargs: Keyword arguments for the middleware constructor.
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        """Iterate over ``(cls, args, kwargs)`` to support tuple unpacking."""
        return iter((self.cls, self.args, self.kwargs))

    def __repr__(self) -> str:
        # e.g. "Middleware(GZipMiddleware, minimum_size=1000)"
        name = getattr(self.cls, "__name__", repr(self.cls))
        parts = [name]
        parts.extend(repr(value) for value in self.args)
        parts.extend(f"{key}={value!r}" for key, value in self.kwargs.items())
        return f"{type(self).__name__}({', '.join(parts)})"
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Scope, Receive, Send
from typing import Callable, Awaitable
class BaseHTTPMiddleware:
    """
    Base class for HTTP middleware.

    Provides a simplified interface for creating middleware that processes
    HTTP requests and responses. Handles ASGI protocol details and provides
    a clean ``dispatch`` method to override. Alternatively, pass a
    ``dispatch`` callable to the constructor instead of subclassing.
    Scopes whose type is not "http" are forwarded to the wrapped app
    untouched.
    """

    def __init__(self, app: ASGIApp, dispatch: Callable | None = None) -> None:
        """
        Initialize HTTP middleware.

        Args:
            app: Next ASGI application in the stack.
            dispatch: Optional custom dispatch function with the same
                ``(request, call_next)`` signature as ``dispatch``; when
                given it is used instead of the method.
        """
        self.app = app
        # Resolve once so __call__ has a single entry point: an explicit
        # ``dispatch`` callable wins over the (possibly overridden) method.
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI application interface."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Process HTTP request through dispatch
        async def call_next(request: Request) -> Response:
            # ... implementation details (elided in this excerpt)
            ...

        request = Request(scope, receive)
        response = await self.dispatch_func(request, call_next)
        await response(scope, receive, send)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """
        Process HTTP request and response.

        Override this method to implement middleware logic.

        Args:
            request: HTTP request object.
            call_next: Function to call the next middleware/endpoint.

        Returns:
            Response: HTTP response object.
        """
        # Default implementation just calls next middleware
        response = await call_next(request)
        return response
from starlette.middleware.cors import CORSMiddleware
from typing import Sequence
class CORSMiddleware:
    """
    Cross-Origin Resource Sharing (CORS) middleware.

    Handles CORS headers for cross-origin requests, including:
    - Preflight request handling
    - Origin validation
    - Credential support
    - Custom headers and methods

    Only the public signatures are reproduced here; method bodies are
    elided in this excerpt.
    """
    def __init__(
        self,
        app: ASGIApp,
        allow_origins: Sequence[str] = (),
        allow_methods: Sequence[str] = ("GET",),
        allow_headers: Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """
        Initialize CORS middleware.

        Args:
            app: ASGI application to wrap.
            allow_origins: Allowed origin URLs ("*" for all). Empty by default.
            allow_methods: Allowed HTTP methods (default: GET only).
            allow_headers: Allowed request headers. NOTE(review): presumably
                combined with the module's SAFELISTED_HEADERS — confirm.
            allow_credentials: Allow credentials (cookies, auth headers);
                False by default.
            allow_origin_regex: Regex pattern for allowed origins. NOTE(review):
                presumably checked in addition to ``allow_origins`` — confirm.
            expose_headers: Response headers to expose to client-side code.
            max_age: How long, in seconds, a preflight result may be cached
                (default 600).
        """
    def is_allowed_origin(self, origin: str) -> bool:
        """Return True if ``origin`` is allowed by this configuration."""
    # NOTE(review): ``Headers`` (starlette.datastructures) is not imported by
    # this excerpt — confirm the import where this signature is used.
    def preflight_response(self, request_headers: Headers) -> Response:
        """Build the response to a CORS preflight request."""
    async def simple_response(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request_headers: Headers
    ) -> None:
        """Handle a simple (non-preflight) CORS request."""
# CORS constants
# HTTP methods the CORS middleware knows about, kept in alphabetical order.
ALL_METHODS = tuple(sorted({"GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}))
# CORS-safelisted request headers, stored lower-cased (header names are
# case-insensitive).
SAFELISTED_HEADERS = {
    name.lower()
    for name in ("Accept", "Accept-Language", "Content-Language", "Content-Type")
}
from starlette.middleware.gzip import GZipMiddleware
class GZipMiddleware:
    """
    GZip compression middleware.

    Automatically compresses response bodies when:
    - Client accepts gzip encoding
    - Response size exceeds minimum threshold
    - Content-Type is compressible
    """
    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int = 500,
        compresslevel: int = 9,
    ) -> None:
        """
        Initialize GZip middleware.

        Args:
            app: ASGI application to wrap.
            minimum_size: Minimum response size to compress, in bytes
                (default 500); smaller responses are left uncompressed.
            compresslevel: gzip compression level, 1-9 (default 9); higher
                values compress better at a higher CPU cost.
        """
from starlette.middleware.sessions import SessionMiddleware
class SessionMiddleware:
    """
    Cookie-based session middleware.

    Stores session data client-side in a *signed*, not encrypted, cookie.
    The signature prevents tampering, but anyone holding the cookie can read
    its contents, so never store secrets in the session.
    Sessions are available as the ``request.session`` dictionary.
    """
    def __init__(
        self,
        app: ASGIApp,
        secret_key: str,
        session_cookie: str = "session",
        max_age: int | None = 14 * 24 * 60 * 60,  # 14 days
        path: str = "/",
        same_site: str = "lax",
        https_only: bool = False,
        domain: str | None = None,
    ) -> None:
        """
        Initialize session middleware.

        Args:
            app: ASGI application to wrap.
            secret_key: Key used to sign the cookie. Anyone who knows it can
                forge sessions, so keep it secret and load it from
                configuration rather than source code.
            session_cookie: Cookie name for session data (default "session").
            max_age: Session lifetime in seconds (default 14 days). None
                makes it a browser-session cookie that lasts until the
                browser is closed.
            path: Cookie path (default "/").
            same_site: SameSite cookie policy: "lax" (default), "strict"
                or "none".
            https_only: Set the cookie's Secure flag so it is only sent over
                HTTPS.
            domain: Cookie domain; None lets the browser scope the cookie to
                the current host.
        """
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.authentication import AuthenticationBackend
from starlette.requests import Request
from starlette.responses import Response
class AuthenticationMiddleware:
    """
    Authentication middleware using pluggable backends.

    Authenticates requests and populates request.user and request.auth
    based on configured authentication backends.
    """
    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: Callable[[Request, Exception], Response] | None = None,
    ) -> None:
        """
        Initialize authentication middleware.

        Args:
            app: ASGI application to wrap.
            backend: Authentication backend implementation.
            on_error: Custom error handler for authentication failures; it is
                called with the request and the exception and returns the
                response to send. NOTE(review): when None, presumably
                ``default_on_error`` is used — confirm.
        """
    def default_on_error(self, conn: Request, exc: Exception) -> Response:
        """Default authentication error handler (same call shape as ``on_error``)."""
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse, HTMLResponse
class ServerErrorMiddleware:
    """
    Server error handling middleware.

    Catches unhandled exceptions and returns appropriate error responses.
    In debug mode, provides detailed error pages with stack traces.
    """
    def __init__(
        self,
        app: ASGIApp,
        handler: Callable[[Request, Exception], Response] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize server error middleware.

        Args:
            app: ASGI application to wrap.
            handler: Custom error handler, called with the request and the
                exception and returning the response to send. NOTE(review):
                precedence over the built-in responses is not shown here —
                confirm.
            debug: Enable detailed error pages with stack traces (default
                False). Keep it off in production: stack traces expose
                internals to clients.
        """
    def debug_response(self, request: Request, exc: Exception) -> Response:
        """Generate debug error response with stack trace."""
    def error_response(self, request: Request, exc: Exception) -> Response:
        """Generate production error response."""
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.exceptions import HTTPException
class ExceptionMiddleware:
    """
    Exception handling middleware for HTTP and WebSocket exceptions.

    Provides centralized exception handling with support for:
    - HTTP exceptions (HTTPException)
    - WebSocket exceptions (WebSocketException)
    - Custom exception handlers
    - Status code handlers
    """
    def __init__(
        self,
        app: ASGIApp,
        handlers: dict[Any, Callable] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize exception middleware.

        Args:
            app: ASGI application to wrap.
            handlers: Exception handler mapping. Keys are exception classes or
                integer status codes (as accepted by add_exception_handler).
            debug: Enable debug mode (default False).
        """
    def add_exception_handler(
        self,
        exc_class_or_status_code: type[Exception] | int,
        handler: Callable,
    ) -> None:
        """Register ``handler`` for an exception class or an HTTP status code."""
    async def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Handle HTTP exceptions."""
    # NOTE(review): WebSocket and WebSocketException are not imported by this
    # excerpt — confirm their imports where this signature is used.
    async def websocket_exception(self, websocket: WebSocket, exc: WebSocketException) -> None:
        """Handle WebSocket exceptions."""
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
class HTTPSRedirectMiddleware:
    """
    HTTPS redirect middleware.

    Automatically redirects HTTP requests to HTTPS for security.

    NOTE(review): the redirect status code, and whether ws:// connections are
    upgraded as well, are not shown in this excerpt — confirm.
    """
    def __init__(self, app: ASGIApp) -> None:
        """Initialize HTTPS redirect middleware.

        Args:
            app: ASGI application to wrap; there are no other options.
        """
class TrustedHostMiddleware:
    """
    Trusted host validation middleware.

    Validates the Host header against allowed hosts to prevent Host header
    injection attacks. A request whose host matches no allowed pattern is
    rejected (Starlette responds 400 "Invalid host header").
    """
    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: Sequence[str] | None = None,
        www_redirect: bool = True,
    ) -> None:
        """
        Initialize trusted host middleware.

        Args:
            app: ASGI application to wrap.
            allowed_hosts: Allowed hostnames. A wildcard pattern such as
                "*.example.com" matches subdomains only; list "example.com"
                separately to allow the apex domain. "*" alone allows every
                host, and None (the default) behaves like ["*"].
            www_redirect: When True (the default), a request whose host is
                not allowed but whose "www." form is allowed is redirected to
                the "www." host instead of being rejected. For example, a
                request to "example.com" is redirected when only
                "www.example.com" is listed. Note the direction: apex to www,
                not www to apex.
        """
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.sessions import SessionMiddleware
import os

from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

# Configure middleware stack
middleware = [
    # Order matters - first in list wraps all others
    Middleware(
        CORSMiddleware,
        allow_origins=["https://example.com"],
        allow_methods=["GET", "POST"],
        allow_headers=["*"],
    ),
    Middleware(GZipMiddleware, minimum_size=1000),
    # Never hard-code the signing key: anyone who knows it can forge session
    # cookies. Read it from the environment; a missing variable raises
    # KeyError at startup instead of silently running with a weak key.
    Middleware(SessionMiddleware, secret_key=os.environ["SESSION_SECRET_KEY"]),
]
app = Starlette(
    routes=routes,
    middleware=middleware,
)
# Add middleware after initialization (before the app starts serving).
# add_middleware() inserts at the front of the stack, so the middleware
# added last (TrustedHostMiddleware) becomes the outermost layer.
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)
from starlette.routing import Route
from starlette.middleware import Middleware
import time

from starlette.middleware.base import BaseHTTPMiddleware

# Middleware for specific routes
class TimingMiddleware(BaseHTTPMiddleware):
    """Add an ``X-Process-Time`` header with the elapsed seconds (float string).

    The measured span runs from entering ``dispatch`` until ``call_next``
    returns a response.
    """
    async def dispatch(self, request, call_next):
        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump and yield bogus durations.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response
routes = [
    # Plain route: only the application-wide middleware applies here.
    Route(path="/", endpoint=homepage),
    # Route-scoped middleware wraps this endpoint only; the first entry
    # (TimingMiddleware) is the outer layer.
    # NOTE(review): assumes a CacheMiddleware whose constructor accepts
    # ``ttl`` — confirm against the implementation actually used.
    Route(
        path="/api/data",
        endpoint=api_endpoint,
        middleware=[
            Middleware(TimingMiddleware),
            Middleware(CacheMiddleware, ttl=300),
        ],
    ),
]
from starlette.types import ASGIApp, Scope, Receive, Send
def logging_middleware(app: ASGIApp) -> ASGIApp:
    """Wrap ``app`` in a minimal pure-ASGI middleware that logs HTTP requests.

    The returned callable prints ``Request: <METHOD> <path>`` for scopes of
    type "http" only, and then always delegates to ``app`` unchanged.
    """

    async def wrapped(scope: Scope, receive: Receive, send: Send) -> None:
        # Guard clause: websocket/lifespan traffic is passed straight through.
        if scope["type"] != "http":
            await app(scope, receive, send)
            return
        method, path = scope["method"], scope["path"]
        print(f"Request: {method} {path}")
        await app(scope, receive, send)

    return wrapped
# Usage
# A plain factory function can be registered with add_middleware() just like
# a middleware class. NOTE(review): it is presumably invoked as
# logging_middleware(app) when the middleware stack is built — confirm.
app = Starlette(routes=routes)
app.add_middleware(logging_middleware)
import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class TimingMiddleware(BaseHTTPMiddleware):
    """Middleware to add timing headers to responses.

    Sets ``X-Process-Time`` to the elapsed seconds (as a float string)
    between entering ``dispatch`` and ``call_next`` returning.
    """
    async def dispatch(self, request: Request, call_next) -> Response:
        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump and yield bogus durations.
        start_time = time.perf_counter()
        # Process request
        response = await call_next(request)
        # Add timing information
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response
from collections import deque

from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-client sliding-window rate limiting middleware.

    Each client, keyed by ``request.client.host`` (or "unknown" when the
    server reports no client), may make at most ``calls_per_minute`` requests
    in any rolling ``WINDOW_SECONDS`` window. Further requests receive a 429
    JSON response with a ``Retry-After`` header. Allowed responses carry an
    ``X-RateLimit-Remaining`` header.

    State lives in this process's memory only, so limits are per worker and
    reset on restart. In production, use Redis or similar.
    """

    # Length of the rolling window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client ip -> time.monotonic() stamps of accepted calls, oldest first.
        self.clients: dict[str, deque[float]] = {}

    async def dispatch(self, request: Request, call_next) -> Response:
        client_ip = request.client.host if request.client else "unknown"
        # monotonic() is immune to wall-clock adjustments that would
        # otherwise stretch or collapse the window.
        now = time.monotonic()
        window_start = now - self.WINDOW_SECONDS

        # Evict clients with no call inside the window so the table cannot
        # grow without bound. Stamps are ordered, so the newest one decides.
        idle = [
            ip for ip, calls in self.clients.items()
            if not calls or calls[-1] <= window_start
        ]
        for ip in idle:
            del self.clients[ip]

        calls = self.clients.setdefault(client_ip, deque())
        # Trim this client's expired stamps. Without this, the history of a
        # continuously active client would grow forever.
        while calls and calls[0] <= window_start:
            calls.popleft()

        if len(calls) >= self.calls_per_minute:
            return JSONResponse(
                {"error": "Rate limit exceeded"},
                status_code=429,
                headers={"Retry-After": str(self.WINDOW_SECONDS)},
            )

        calls.append(now)
        # Snapshot before awaiting: concurrent requests may append meanwhile.
        remaining = self.calls_per_minute - len(calls)
        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response
from starlette.responses import JSONResponse


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Custom API-key authentication middleware.

    Requests for a path in ``PUBLIC_PATHS`` pass straight through. Every
    other request must send an ``X-API-Key`` header that ``validate_api_key``
    resolves to a truthy user object. That object is stored on
    ``request.state.user``; otherwise a 401 JSON response is returned.

    NOTE: this class shares its name with Starlette's own
    ``starlette.middleware.authentication.AuthenticationMiddleware``, so
    avoid importing both into one module.
    """

    # Exact paths served without an API key. Defined as a class attribute so
    # subclasses can override it.
    PUBLIC_PATHS = frozenset({"/", "/health", "/docs"})

    async def dispatch(self, request: Request, call_next) -> Response:
        # Skip authentication for public endpoints
        if request.url.path in self.PUBLIC_PATHS:
            return await call_next(request)

        # Check for API key
        api_key = request.headers.get("X-API-Key")
        if not api_key:
            return JSONResponse(
                {"error": "API key required"},
                status_code=401,
                headers={"WWW-Authenticate": "ApiKey"},
            )

        # Validate API key
        user = await validate_api_key(api_key)
        if not user:
            # A 401 response must carry a WWW-Authenticate challenge
            # (RFC 9110, section 15.5.2), same as the missing-key case.
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401,
                headers={"WWW-Authenticate": "ApiKey"},
            )

        # Add user to request state
        request.state.user = user
        return await call_next(request)
async def validate_api_key(api_key: str):
    """Resolve an API key to a user object.

    Placeholder: implement your API key validation logic here. Return a
    truthy user object for a valid key, or None to reject it. As written,
    every key is rejected.
    """
    return None


import uuid


class RequestIDMiddleware:
    """Add unique request ID to each request.

    For "http" scopes, a fresh UUID4 string is stored in
    ``scope["request_id"]`` for downstream apps. The same ID is appended to
    the response as an ``x-request-id`` header. Other scope types pass
    through unchanged.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Generate unique request ID and expose it to downstream apps.
        request_id = str(uuid.uuid4())
        scope["request_id"] = request_id

        # Wrap send to add the header to the response start message.
        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append([b"x-request-id", request_id.encode()])
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)
class CompressionMiddleware:
"""Custom compression middleware with multiple algorithms."""
def __init__(self, app: ASGIApp, algorithms: list[str] = None):
self.app = app
self.algorithms = algorithms or ["gzip", "br"]
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
# Check if client accepts compression
headers = dict(scope.get("headers", []))
accept_encoding = headers.get(b"accept-encoding", b"").decode()
selected_encoding = None
for encoding in self.algorithms:
if encoding in accept_encoding:
selected_encoding = encoding
break
if not selected_encoding:
await self.app(scope, receive, send)
return
# Wrap response to compress body
response_complete = False
async def send_wrapper(message):
nonlocal response_complete
if message["type"] == "http.response.start":
# Add compression header
headers = list(message.get("headers", []))
headers.append([b"content-encoding", selected_encoding.encode()])
message = {**message, "headers": headers}
elif message["type"] == "http.response.body":
if not response_complete:
body = message.get("body", b"")
if body:
# Compress body based on selected encoding
if selected_encoding == "gzip":
body = gzip.compress(body)
elif selected_encoding == "br":
body = brotli.compress(body)
message = {**message, "body": body}
await send(message)
await self.app(scope, receive, send_wrapper)class SafeMiddleware(BaseHTTPMiddleware):
"""Middleware with proper error handling."""
async def dispatch(self, request: Request, call_next) -> Response:
try:
# Pre-processing
await self.before_request(request)
# Call next middleware/endpoint
response = await call_next(request)
# Post-processing
await self.after_request(request, response)
return response
except Exception as e:
# Handle middleware errors
return await self.handle_error(request, e)
async def before_request(self, request: Request):
"""Override for pre-processing logic."""
pass
async def after_request(self, request: Request, response: Response):
"""Override for post-processing logic."""
pass
async def handle_error(self, request: Request, exc: Exception) -> Response:
"""Override for error handling."""
print(f"Middleware error: {exc}")
return JSONResponse(
{"error": "Internal server error"},
status_code=500
)class ConditionalMiddleware(BaseHTTPMiddleware):
"""Middleware that applies conditionally."""
def __init__(self, app, condition_func: Callable[[Request], bool]):
super().__init__(app)
self.should_apply = condition_func
async def dispatch(self, request: Request, call_next) -> Response:
# Check if middleware should apply
if not self.should_apply(request):
return await call_next(request)
# Apply middleware logic
return await self.process_request(request, call_next)
async def process_request(self, request: Request, call_next) -> Response:
# Override this method
return await call_next(request)
# Usage examples
def is_api_request(request: Request) -> bool:
return request.url.path.startswith("/api/")
def is_authenticated(request: Request) -> bool:
return "authorization" in request.headers
# Apply different middleware based on conditions
app.add_middleware(
ConditionalMiddleware,
condition_func=is_api_request
)class DatabaseMiddleware(BaseHTTPMiddleware):
"""Provide database connection to requests."""
async def dispatch(self, request: Request, call_next) -> Response:
# Create database connection
async with database.connection() as conn:
# Add to request state
request.state.db = conn
# Process request with database available
response = await call_next(request)
return response
class CacheMiddleware(BaseHTTPMiddleware):
"""Cache responses (requires database middleware)."""
async def dispatch(self, request: Request, call_next) -> Response:
# This middleware depends on DatabaseMiddleware
if not hasattr(request.state, 'db'):
raise RuntimeError("CacheMiddleware requires DatabaseMiddleware")
cache_key = f"response:{request.url}"
# Try cache first
cached = await request.state.db.get_cache(cache_key)
if cached:
return JSONResponse(cached)
# Generate response
response = await call_next(request)
# Cache response if successful
if response.status_code == 200:
await request.state.db.set_cache(cache_key, response.body)
return response
# Order matters - DatabaseMiddleware must come first
middleware = [
Middleware(DatabaseMiddleware),
Middleware(CacheMiddleware),
]Starlette's middleware system provides a powerful way to implement cross-cutting concerns with a clean, composable architecture that supports both simple and complex use cases.
Install with Tessl CLI
npx tessl i tessl/pypi-starlette