CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-starlette

The little ASGI library that shines.

Overview
Eval results
Files

docs/middleware.md

Middleware System

Starlette provides a powerful middleware system for processing requests and responses in a composable pipeline. Middleware can handle cross-cutting concerns like CORS, compression, authentication, error handling, and custom request/response processing.

Middleware Base Classes

Middleware Configuration

from collections.abc import Iterator
from typing import Any, Sequence

from starlette.middleware import Middleware

class Middleware:
    """
    Middleware configuration wrapper.

    Encapsulates a middleware class (or factory) with its initialization
    arguments for use in an application's middleware stack configuration.
    Instances unpack as ``cls, args, kwargs`` so a stack builder can write
    ``for cls, args, kwargs in middleware: app = cls(app, *args, **kwargs)``.
    """

    def __init__(self, cls: type, *args: Any, **kwargs: Any) -> None:
        """
        Initialize middleware configuration.

        Args:
            cls: Middleware class
            *args: Positional arguments for middleware constructor
            **kwargs: Keyword arguments for middleware constructor
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        """Yield ``cls``, ``args`` and ``kwargs`` so the wrapper can be unpacked."""
        return iter((self.cls, self.args, self.kwargs))

    def __repr__(self) -> str:
        """Return e.g. ``Middleware(GZipMiddleware, minimum_size=1000)``."""
        # getattr: cls may be any callable factory, not necessarily a class.
        parts = [getattr(self.cls, "__name__", "")]
        parts.extend(repr(value) for value in self.args)
        parts.extend(f"{key}={value!r}" for key, value in self.kwargs.items())
        return f"{type(self).__name__}({', '.join(parts)})"

Base HTTP Middleware

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Scope, Receive, Send
from typing import Callable, Awaitable

class BaseHTTPMiddleware:
    """
    Base class for HTTP middleware.

    Provides a simplified interface for creating middleware that processes
    HTTP requests and responses. Handles ASGI protocol details and provides
    a clean dispatch method to override. Alternatively, pass a ``dispatch``
    callable to the constructor. Non-HTTP scopes (e.g. websocket, lifespan)
    are forwarded to the wrapped app untouched.
    """

    def __init__(self, app: ASGIApp, dispatch: Callable | None = None) -> None:
        """
        Initialize HTTP middleware.

        Args:
            app: Next ASGI application in the stack
            dispatch: Optional custom dispatch function, called as
                ``dispatch(request, call_next)`` instead of ``self.dispatch``
        """
        self.app = app
        # Always set dispatch_func so __call__ has a single entry point; the
        # (possibly overridden) dispatch method is used unless one is given.
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI application interface."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Process HTTP request through dispatch
        async def call_next(request: Request) -> Response:
            # Simplified for documentation: run the downstream app and buffer
            # its whole response in memory. (Starlette's real implementation
            # streams the body through an in-memory channel instead.)
            status_code: int | None = None
            raw_headers: list = []
            body_parts: list[bytes] = []

            async def capture(message) -> None:
                nonlocal status_code, raw_headers
                if message["type"] == "http.response.start":
                    status_code = message["status"]
                    raw_headers = list(message.get("headers", []))
                elif message["type"] == "http.response.body":
                    body_parts.append(message.get("body", b""))

            await self.app(request.scope, request.receive, capture)
            if status_code is None:
                raise RuntimeError("No response returned.")

            response = Response(b"".join(body_parts), status_code=status_code)
            # Use the downstream headers verbatim so repeated headers such as
            # Set-Cookie survive.
            response.raw_headers = raw_headers
            return response

        request = Request(scope, receive)
        response = await self.dispatch_func(request, call_next)
        await response(scope, receive, send)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """
        Process HTTP request and response.

        Override this method to implement middleware logic.

        Args:
            request: HTTP request object
            call_next: Function to call next middleware/endpoint

        Returns:
            Response: HTTP response object
        """
        # Default implementation just calls next middleware
        response = await call_next(request)
        return response

Built-in Middleware

CORS Middleware

from starlette.middleware.cors import CORSMiddleware
from typing import Sequence

class CORSMiddleware:
    """
    Cross-Origin Resource Sharing (CORS) middleware.

    Handles CORS headers for cross-origin requests, including:
    - Preflight request handling
    - Origin validation
    - Credential support
    - Custom headers and methods

    This is an interface summary; method bodies are omitted.
    """

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: Sequence[str] = (),
        allow_methods: Sequence[str] = ("GET",),
        allow_headers: Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """
        Initialize CORS middleware.

        Args:
            app: ASGI application
            allow_origins: Allowed origin URLs ("*" for all). Empty by
                default, i.e. no cross-origin access unless configured.
            allow_methods: Allowed HTTP methods (default: GET only)
            allow_headers: Allowed request headers
            allow_credentials: Allow credentials (cookies, auth headers).
                Per the Starlette docs, origins, methods and headers must be
                listed explicitly (not "*") for credentials to be allowed.
            allow_origin_regex: Regex pattern for allowed origins
            expose_headers: Headers to expose to client
            max_age: Preflight cache duration in seconds (default 600)
        """

    def is_allowed_origin(self, origin: str) -> bool:
        """Return whether *origin* is permitted by allow_origins / allow_origin_regex."""

    def preflight_response(self, request_headers: "Headers") -> Response:
        """Build the response to a CORS preflight (OPTIONS) request."""

    async def simple_response(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request_headers: "Headers"
    ) -> None:
        """Handle a simple (non-preflight) CORS request."""

# CORS constants
# The HTTP methods listed for CORS handling, in alphabetical order.
ALL_METHODS = tuple("DELETE GET HEAD OPTIONS PATCH POST PUT".split())

# Lower-case request header names treated as CORS-safelisted.
SAFELISTED_HEADERS = set(
    "accept accept-language content-language content-type".split()
)

GZip Compression Middleware

from starlette.middleware.gzip import GZipMiddleware

class GZipMiddleware:
    """
    GZip compression middleware.

    Automatically compresses response bodies when:
    - Client accepts gzip encoding
    - Response size exceeds minimum threshold
    - Content-Type is compressible

    This is an interface summary; the implementation is omitted.
    """

    def __init__(
        self,
        app: ASGIApp,
        minimum_size: int = 500,
        compresslevel: int = 9,
    ) -> None:
        """
        Initialize GZip middleware.

        Args:
            app: ASGI application
            minimum_size: Minimum response size to compress (bytes); smaller
                responses are sent uncompressed (default 500)
            compresslevel: gzip compression level (default 9); higher values
                give smaller output at the cost of more CPU time
        """

Session Middleware

from starlette.middleware.sessions import SessionMiddleware

class SessionMiddleware:
    """
    Cookie-based session middleware.

    Stores session data in a cookie that is *signed* with ``secret_key``:
    tampering is detected, but the data is NOT encrypted and the client
    can read it, so never store secrets in the session.
    Sessions are available as the request.session dictionary.
    """

    def __init__(
        self,
        app: ASGIApp,
        secret_key: str,
        session_cookie: str = "session",
        max_age: int | None = 14 * 24 * 60 * 60,  # 14 days
        path: str = "/",
        same_site: str = "lax",
        https_only: bool = False,
        domain: str | None = None,
    ) -> None:
        """
        Initialize session middleware.

        Args:
            app: ASGI application
            secret_key: Secret key for signing cookies (keep secret!). Anyone
                holding it can forge valid session cookies.
            session_cookie: Cookie name for session data
            max_age: Session lifetime in seconds (default 14 days). Per the
                Starlette docs, None makes the cookie last only as long as
                the browser session.
            path: Cookie path
            same_site: SameSite cookie policy ("lax", "strict" or "none")
            https_only: Set the cookie's Secure flag so it is only sent
                over HTTPS
            domain: Cookie domain
        """

Authentication Middleware

from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.authentication import AuthenticationBackend
from starlette.requests import Request
from starlette.responses import Response

class AuthenticationMiddleware:
    """
    Authentication middleware using a pluggable backend.

    Authenticates requests and populates request.user and request.auth
    based on the configured authentication backend.

    This is an interface summary; the implementation is omitted.
    """

    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: Callable[[Request, Exception], Response] | None = None,
    ) -> None:
        """
        Initialize authentication middleware.

        Args:
            app: ASGI application
            backend: Authentication backend implementation
            on_error: Custom error handler for authentication failures,
                called as ``on_error(conn, exc)`` and returning the error
                response; presumably default_on_error is used when None
                (confirm against the Starlette source)
        """

    def default_on_error(self, conn: Request, exc: Exception) -> Response:
        """Default authentication error handler; builds a response from *exc*."""

Error Handling Middleware

from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse, HTMLResponse

class ServerErrorMiddleware:
    """
    Server error handling middleware.

    Catches unhandled exceptions and returns appropriate error responses.
    In debug mode, provides detailed error pages with stack traces
    (debug_response); otherwise the custom handler or error_response builds
    the response (presumably the handler takes precedence when given —
    confirm against the Starlette source).
    """

    def __init__(
        self,
        app: ASGIApp,
        handler: Callable[[Request, Exception], Response] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize server error middleware.

        Args:
            app: ASGI application
            handler: Custom error handler function, called as
                ``handler(request, exc)``
            debug: Enable debug mode with detailed error pages
        """

    def debug_response(self, request: Request, exc: Exception) -> Response:
        """Generate debug error response with stack trace."""

    def error_response(self, request: Request, exc: Exception) -> Response:
        """Generate production error response (no internal details)."""

Exception Middleware

from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.exceptions import HTTPException

class ExceptionMiddleware:
    """
    Exception handling middleware for HTTP and WebSocket exceptions.

    Provides centralized exception handling with support for:
    - HTTP exceptions (HTTPException)
    - WebSocket exceptions (WebSocketException)
    - Custom exception handlers
    - Status code handlers

    This is an interface summary; method bodies are omitted.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: dict[Any, Callable] | None = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize exception middleware.

        Args:
            app: ASGI application
            handlers: Exception handler mapping; keys are exception classes
                or integer status codes (as accepted by add_exception_handler)
            debug: Enable debug mode
        """

    def add_exception_handler(
        self,
        exc_class_or_status_code: type[Exception] | int,
        handler: Callable,
    ) -> None:
        """Register *handler* for an exception class or an HTTP status code."""

    async def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Handle HTTP exceptions by turning them into a response."""

    async def websocket_exception(self, websocket: "WebSocket", exc: "WebSocketException") -> None:
        """Handle WebSocket exceptions on the given connection."""

Security Middleware

from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

class HTTPSRedirectMiddleware:
    """
    HTTPS redirect middleware.

    Redirects insecure requests to the secure scheme. Per the Starlette
    docs this covers both http -> https and ws -> wss.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Initialize HTTPS redirect middleware wrapping *app*."""

class TrustedHostMiddleware:
    """
    Trusted host validation middleware.

    Validates the Host header against allowed_hosts to prevent
    Host header injection attacks. Per the Starlette docs, requests whose
    host does not validate receive a 400 response.
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: Sequence[str] | None = None,
        www_redirect: bool = True,
    ) -> None:
        """
        Initialize trusted host middleware.

        Args:
            app: ASGI application
            allowed_hosts: Hostnames the app may be served under. Wildcard
                entries such as "*.example.com" match subdomains (the bare
                "example.com" must be listed separately); "*" allows any
                host. None presumably means "allow any host" — confirm.
            www_redirect: If True, a request for a non-www host whose
                "www." form is in allowed_hosts is redirected to that www
                host (example.com -> www.example.com), not the reverse.
        """

Configuring the Middleware Stack

Application-level Middleware

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.sessions import SessionMiddleware

from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

# Configure middleware stack.
# Order matters - first in list wraps all others: it sees the request first
# and the response last (CORS -> GZip -> Session -> routes).
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["https://example.com"],
        allow_methods=["GET", "POST"],
        allow_headers=["*"],
    ),
    Middleware(GZipMiddleware, minimum_size=1000),
    # Placeholder only: load the real key from configuration or the
    # environment. Anyone holding it can forge session cookies.
    Middleware(SessionMiddleware, secret_key="secret-key-here"),
]

app = Starlette(
    routes=routes,  # your application's list of Route/Mount objects
    middleware=middleware,
)

# Add middleware after initialization. Each add_middleware() call inserts the
# new layer as the outermost one, so after these two calls TrustedHostMiddleware
# runs first, then HTTPSRedirectMiddleware, then the stack configured above.
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com"],
)

Route-specific Middleware

from starlette.routing import Route
from starlette.middleware import Middleware

import time


# Middleware for specific routes
class TimingMiddleware(BaseHTTPMiddleware):
    """Add an X-Process-Time header (seconds spent until call_next returns)."""

    async def dispatch(self, request, call_next):
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and yield bogus (even negative) times.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

# homepage and api_endpoint stand for your own endpoint functions.
# CacheMiddleware here is a placeholder for a middleware accepting a ``ttl``
# option (the CacheMiddleware defined later on this page takes no options).
routes = [
    Route("/", homepage),  # No extra middleware
    Route(
        "/api/data",
        api_endpoint,
        middleware=[
            Middleware(TimingMiddleware),
            Middleware(CacheMiddleware, ttl=300),
        ],
    ),
]

Custom Middleware

Simple Function Middleware

from starlette.types import ASGIApp, Scope, Receive, Send

def logging_middleware(app: ASGIApp) -> ASGIApp:
    """Wrap *app* so each HTTP request's method and path are printed first.

    Non-HTTP scopes are forwarded without output.
    """

    async def logged_app(scope: Scope, receive: Receive, send: Send) -> None:
        is_http = scope["type"] == "http"
        if is_http:
            method, path = scope["method"], scope["path"]
            print(f"Request: {method} {path}")
        await app(scope, receive, send)

    return logged_app

# Usage: add_middleware() is given the factory itself; it is expected to call
# it with the inner ASGI app and use the returned callable as the new layer,
# which is why a plain function taking ``app`` works like a middleware class.
app = Starlette(routes=routes)
app.add_middleware(logging_middleware)

Class-based Middleware

import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

class TimingMiddleware(BaseHTTPMiddleware):
    """Middleware to add timing headers to responses.

    Sets X-Process-Time to the seconds elapsed until call_next returned.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock (NTP adjustments etc.) and can give bogus durations.
        start_time = time.perf_counter()

        # Process request
        response = await call_next(request)

        # Add timing information
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)

        return response

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client-IP sliding-window rate limiter.

    Allows at most ``calls_per_minute`` requests per client IP in any
    60-second window. Further requests get a JSON 429 response with a
    Retry-After header; allowed responses carry X-RateLimit-Remaining.
    State lives in this process only.
    """

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client IP -> time.monotonic() stamps of its calls in the current
        # window, oldest first. In production, use Redis or similar.
        self.clients = {}
        # When idle clients were last evicted from self.clients.
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next) -> Response:
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock jumps must not reset or extend windows.
        now = time.monotonic()
        minute_ago = now - 60

        # Evict clients with no calls in the window, at most once per minute,
        # instead of rebuilding the whole table on every request.
        if now - self._last_sweep >= 60:
            self.clients = {
                ip: times for ip, times in self.clients.items()
                if times and times[-1] > minute_ago
            }
            self._last_sweep = now

        # Drop this client's expired stamps and store the pruned list, so its
        # history never grows beyond calls_per_minute entries.
        recent_calls = [t for t in self.clients.get(client_ip, ()) if t > minute_ago]
        self.clients[client_ip] = recent_calls

        if len(recent_calls) >= self.calls_per_minute:
            return JSONResponse(
                {"error": "Rate limit exceeded"},
                status_code=429,
                headers={"Retry-After": "60"}
            )

        # Add current call
        recent_calls.append(now)
        # Computed before awaiting: concurrent requests from the same client
        # may append to the stored list while call_next is running.
        remaining = self.calls_per_minute - len(recent_calls)

        # Process request
        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)

        return response

class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Custom API-key authentication middleware.

    Requests for "/", "/health" and "/docs" bypass the check. Every other
    request must send an X-API-Key header that validate_api_key resolves to
    a user, who is then exposed as request.state.user. A missing or invalid
    key yields a JSON 401 response. Note: this name shadows Starlette's own
    AuthenticationMiddleware if both are imported into one module.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        public_paths = ["/", "/health", "/docs"]
        if request.url.path not in public_paths:
            api_key = request.headers.get("X-API-Key")
            if not api_key:
                return JSONResponse(
                    {"error": "API key required"},
                    status_code=401,
                    headers={"WWW-Authenticate": "ApiKey"},
                )

            user = await validate_api_key(api_key)
            if not user:
                return JSONResponse({"error": "Invalid API key"}, status_code=401)

            # Downstream handlers read the authenticated user from here.
            request.state.user = user

        return await call_next(request)

async def validate_api_key(api_key: str):
    """Resolve *api_key* to its user object, or None if it is not valid.

    Placeholder: implement your API key validation logic here. As written it
    always returns None, so every protected request is rejected (fail closed).
    """
    return None

ASGI Middleware

import uuid


class RequestIDMiddleware:
    """Add a unique request ID to each HTTP request.

    The ID is stored in scope["request_id"] for downstream code and echoed
    to the client in an x-request-id response header. Non-HTTP scopes are
    passed through unchanged.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # uuid4 is random, so no shared counter or coordination is needed.
        request_id = str(uuid.uuid4())
        scope["request_id"] = request_id

        # Wrap send to add header to response
        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                # Copy rather than mutate the app's own header list.
                headers = list(message.get("headers", []))
                headers.append([b"x-request-id", request_id.encode()])
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)

class CompressionMiddleware:
    """Custom compression middleware with multiple algorithms."""
    
    def __init__(self, app: ASGIApp, algorithms: list[str] = None):
        self.app = app
        self.algorithms = algorithms or ["gzip", "br"]
    
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        
        # Check if client accepts compression
        headers = dict(scope.get("headers", []))
        accept_encoding = headers.get(b"accept-encoding", b"").decode()
        
        selected_encoding = None
        for encoding in self.algorithms:
            if encoding in accept_encoding:
                selected_encoding = encoding
                break
        
        if not selected_encoding:
            await self.app(scope, receive, send)
            return
        
        # Wrap response to compress body
        response_complete = False
        
        async def send_wrapper(message):
            nonlocal response_complete
            
            if message["type"] == "http.response.start":
                # Add compression header
                headers = list(message.get("headers", []))
                headers.append([b"content-encoding", selected_encoding.encode()])
                message = {**message, "headers": headers}
                
            elif message["type"] == "http.response.body":
                if not response_complete:
                    body = message.get("body", b"")
                    if body:
                        # Compress body based on selected encoding
                        if selected_encoding == "gzip":
                            body = gzip.compress(body)
                        elif selected_encoding == "br":
                            body = brotli.compress(body)
                        
                        message = {**message, "body": body}
            
            await send(message)
        
        await self.app(scope, receive, send_wrapper)

Middleware Best Practices

Error Handling in Middleware

import logging

logger = logging.getLogger(__name__)


class SafeMiddleware(BaseHTTPMiddleware):
    """Middleware with proper error handling.

    Pre-processing, the downstream call and post-processing share one
    ``try``: an exception from any of them — including one raised by the
    endpoint itself through call_next — is turned into handle_error's
    response.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        try:
            # Pre-processing
            await self.before_request(request)

            # Call next middleware/endpoint
            response = await call_next(request)

            # Post-processing
            await self.after_request(request, response)

            return response

        except Exception as e:
            # Handle middleware (and downstream) errors
            return await self.handle_error(request, e)

    async def before_request(self, request: Request):
        """Override for pre-processing logic."""
        pass

    async def after_request(self, request: Request, response: Response):
        """Override for post-processing logic."""
        pass

    async def handle_error(self, request: Request, exc: Exception) -> Response:
        """Override for error handling.

        Default: log *exc* with its traceback and return a generic JSON 500,
        so internal details are not exposed to the client.
        """
        # logging (not print) keeps the traceback and honours log config.
        logger.error("Middleware error: %s", exc, exc_info=exc)
        return JSONResponse(
            {"error": "Internal server error"},
            status_code=500
        )

Conditional Middleware

class ConditionalMiddleware(BaseHTTPMiddleware):
    """Run process_request only for requests selected by a predicate.

    Requests for which condition_func(request) is falsy go straight to the
    next layer. Subclasses put their logic in process_request.
    """

    def __init__(self, app, condition_func: Callable[[Request], bool]):
        super().__init__(app)
        # Per-request predicate deciding whether this middleware applies.
        self.should_apply = condition_func

    async def dispatch(self, request: Request, call_next) -> Response:
        applies = self.should_apply(request)
        if applies:
            return await self.process_request(request, call_next)
        return await call_next(request)

    async def process_request(self, request: Request, call_next) -> Response:
        """Middleware logic for matching requests; the default passes through."""
        return await call_next(request)

# Usage examples
def is_api_request(request: Request) -> bool:
    """Return True when the request path is under the "/api/" prefix."""
    path = request.url.path
    return path[:5] == "/api/"

def is_authenticated(request: Request) -> bool:
    """Return True when the request carries an "authorization" header.

    Only presence is checked; the credentials are not validated.
    """
    headers = request.headers
    return "authorization" in headers

# Apply different middleware based on conditions: this layer's logic only
# runs for request paths under /api/.
app.add_middleware(ConditionalMiddleware, condition_func=is_api_request)

Middleware Dependencies

class DatabaseMiddleware(BaseHTTPMiddleware):
    """Provide a database connection to each request as request.state.db.

    The connection comes from ``database.connection()`` (defined elsewhere)
    and its context is exited as soon as call_next returns, so code using
    request.state.db after that point (e.g. while a streamed body is still
    being produced) may find it closed — verify for your driver.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        async with database.connection() as connection:
            # Inner layers (e.g. CacheMiddleware) read it from request.state.
            request.state.db = connection
            response = await call_next(request)
        return response

class CacheMiddleware(BaseHTTPMiddleware):
    """Cache successful GET response bodies (requires DatabaseMiddleware).

    Bodies are stored as bytes under ``response:<url>`` via the connection
    that DatabaseMiddleware puts in request.state.db, and cache hits are
    replayed as application/json (this example assumes JSON endpoints).

    Raises:
        RuntimeError: if DatabaseMiddleware has not run first.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        # This middleware depends on DatabaseMiddleware
        if not hasattr(request.state, 'db'):
            raise RuntimeError("CacheMiddleware requires DatabaseMiddleware")

        # Only safe, idempotent requests may be served from a cache.
        if request.method != "GET":
            return await call_next(request)

        db = request.state.db
        cache_key = f"response:{request.url}"

        # Try cache first; an empty cached body is still a hit.
        cached = await db.get_cache(cache_key)
        if cached is not None:
            return Response(content=cached, media_type="application/json")

        # Generate response
        response = await call_next(request)
        if response.status_code != 200:
            return response

        # Responses from call_next are streamed and have no ``.body``: drain
        # the iterator once, cache the bytes, then replay them to the client.
        body = b"".join([chunk async for chunk in response.body_iterator])
        await db.set_cache(cache_key, body)

        replay = Response(content=body, status_code=response.status_code)
        # Keep the original headers verbatim (incl. repeated ones).
        replay.raw_headers = response.raw_headers
        return replay

# Order matters - DatabaseMiddleware must come first: the first entry is the
# outermost layer, so request.state.db is already set when CacheMiddleware
# runs (CacheMiddleware raises RuntimeError otherwise).
middleware = [
    Middleware(DatabaseMiddleware),
    Middleware(CacheMiddleware),
]

Starlette's middleware system provides a powerful way to implement cross-cutting concerns with a clean, composable architecture that supports both simple and complex use cases.

Install with Tessl CLI

npx tessl i tessl/pypi-starlette

docs

authentication.md

core-application.md

data-structures.md

exceptions-status.md

index.md

middleware.md

requests-responses.md

routing.md

static-files.md

testing.md

websockets.md

tile.json