Litestar is a powerful, flexible yet opinionated ASGI web framework specifically focused on building high-performance APIs.
—
Middleware components for request/response processing, including authentication, CORS, compression, rate limiting, and custom middleware. Litestar supports both function-based and class-based middleware with a flexible pipeline architecture.
Base classes for implementing custom middleware components.
class AbstractMiddleware:
    """Pass-through base class for ASGI middleware.

    Subclasses receive the downstream application in ``__init__`` and override
    ``__call__``; this default implementation forwards every call unchanged.
    """

    def __init__(self, app: ASGIApp):
        """
        Store the application this middleware wraps.
        Parameters:
        - app: downstream ASGI application
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Forward the ASGI call to the wrapped application unchanged.
        Parameters:
        - scope: connection scope
        - receive: callable yielding incoming ASGI messages
        - send: callable emitting outgoing ASGI messages
        """
        downstream = self.app
        await downstream(scope, receive, send)
class ASGIMiddleware(AbstractMiddleware):
    """Wrap another middleware class together with its constructor arguments (API stub)."""
    def __init__(self, app: ASGIApp, middleware: type[AbstractMiddleware], **kwargs: Any):
        """
        Wrapper for ASGI middleware.
        Parameters:
        - app: ASGI application
        - middleware: Middleware class to instantiate
        - **kwargs: Arguments for middleware constructor
        NOTE(review): the body is elided in this listing. As written it neither
        calls super().__init__(app) nor stores `middleware`/`kwargs`, so the
        inherited __call__ would fail on the missing `self.app`. Confirm against
        the real implementation.
        """
# Configuration classes for defining and organizing middleware.
class DefineMiddleware:
    """Bundle a middleware class (or factory) with the arguments used to build it.

    Typically placed in a ``middleware=[...]`` list when the middleware needs
    constructor arguments beyond the ASGI app it wraps.
    """

    def __init__(
        self,
        middleware: type[AbstractMiddleware] | Callable[..., ASGIApp],
        *args: Any,
        **kwargs: Any,
    ):
        """
        Configure middleware with arguments.
        Parameters:
        - middleware: Middleware class or factory function
        - *args: Positional arguments for middleware
        - **kwargs: Keyword arguments for middleware
        """
        # Keep the inputs so the read-only properties below can return them.
        self._middleware = middleware
        self._args = args
        self._kwargs = kwargs

    @property
    def middleware(self) -> type[AbstractMiddleware] | Callable[..., ASGIApp]:
        """Get the middleware class or factory."""
        return self._middleware

    @property
    def args(self) -> tuple[Any, ...]:
        """Get positional arguments."""
        return self._args

    @property
    def kwargs(self) -> dict[str, Any]:
        """Get keyword arguments."""
        return self._kwargs

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}({self._middleware!r}, "
            f"args={self._args!r}, kwargs={self._kwargs!r})"
        )
def middleware(
    middleware_class: type[AbstractMiddleware],
    *args: Any,
    **kwargs: Any,
) -> Callable[[T], T]:
    """
    Decorator to apply middleware to route handlers.
    Parameters:
    - middleware_class: Middleware class to apply
    - *args: Positional arguments for middleware
    - **kwargs: Keyword arguments for middleware
    Returns:
    Decorator function
    NOTE(review): only the signature is shown in this listing. As written the
    body is the docstring alone, so calling this returns None rather than a
    decorator. `T` is presumably a TypeVar defined elsewhere; confirm.
    """
# Pre-built middleware components for common use cases.
class CORSMiddleware(AbstractMiddleware):
    """CORS middleware; only the constructor signature is shown in this listing."""
    def __init__(
        self,
        app: ASGIApp,
        *,
        allow_origins: Sequence[str] = ("*",),
        allow_methods: Sequence[str] = ("GET",),
        allow_headers: Sequence[str] = (),
        allow_credentials: bool = False,
        allow_origin_regex: str | None = None,
        expose_headers: Sequence[str] = (),
        max_age: int = 600,
    ):
        """
        CORS (Cross-Origin Resource Sharing) middleware.
        Parameters:
        - app: ASGI application
        - allow_origins: Allowed origin patterns
        - allow_methods: Allowed HTTP methods
        - allow_headers: Allowed request headers
        - allow_credentials: Allow credentials in requests
        - allow_origin_regex: Regex pattern for allowed origins
        - expose_headers: Headers to expose to client
        - max_age: Preflight cache duration in seconds
        NOTE(review): body elided here (it does not call super().__init__ as shown).
        Per the Fetch standard, browsers reject a wildcard origin combined with
        credentials; confirm how allow_credentials=True interacts with the
        default allow_origins=("*",).
        """
class CompressionMiddleware(AbstractMiddleware):
    """Response-compression middleware; only the constructor signature is shown."""
    def __init__(
        self,
        app: ASGIApp,
        *,
        backend: Literal["gzip", "brotli"] = "gzip",
        minimum_size: int = 500,
        exclude_opt_key: str = "skip_compression",
        compression_facade: CompressionFacade | None = None,
    ):
        """
        Response compression middleware.
        Parameters:
        - app: ASGI application
        - backend: Compression algorithm ("gzip" or "brotli")
        - minimum_size: Minimum response size to compress (presumably bytes; confirm)
        - exclude_opt_key: Key to exclude routes from compression
        - compression_facade: Custom compression implementation
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
class SessionMiddleware(AbstractMiddleware):
    """Cookie-keyed session middleware backed by a Store; only the constructor signature is shown."""
    def __init__(
        self,
        app: ASGIApp,
        store: Store,
        *,
        secret: str | Secret,
        key: str = "session",
        max_age: int = 1209600,  # 14 days
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = True,
        samesite: Literal["lax", "strict", "none"] = "lax",
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_session",
    ):
        """
        Session management middleware.
        Parameters:
        - app: ASGI application (required positionally, so this class cannot be
          instantiated before the app exists)
        - store: Session storage backend
        - secret: Secret key for session signing
        - key: Session cookie name
        - max_age: Session lifetime in seconds (default 1209600 = 14 days)
        - path: Cookie path
        - domain: Cookie domain
        - secure: Use secure cookies (defaults to False; set True behind HTTPS)
        - httponly: Use HTTP-only cookies
        - samesite: SameSite cookie attribute
        - exclude: Paths to exclude from session handling
        - exclude_opt_key: Route option key for exclusion
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
class TrustedHostMiddleware(AbstractMiddleware):
    """Host-header allow-list middleware; only the constructor signature is shown."""
    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: Sequence[str],
        *,
        exclude_opt_key: str = "skip_allowed_hosts_check",
        exclude: str | list[str] | None = None,
        redirect_domains: Sequence[str] | None = None,
    ):
        """
        Trusted host validation middleware.
        Parameters:
        - app: ASGI application
        - allowed_hosts: List of allowed host patterns
        - exclude_opt_key: Route option key for exclusion
        - exclude: Paths to exclude from host checking
        - redirect_domains: Domains to redirect to if host is not allowed
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
class CSRFMiddleware(AbstractMiddleware):
    """CSRF-token middleware; only the constructor signature is shown in this listing."""
    def __init__(
        self,
        app: ASGIApp,
        *,
        secret: str | Secret,
        cookie_name: str = "csrftoken",
        cookie_path: str = "/",
        cookie_domain: str | None = None,
        cookie_secure: bool = False,
        cookie_httponly: bool = False,
        cookie_samesite: Literal["lax", "strict", "none"] = "lax",
        header_name: str = "x-csrftoken",
        safe_methods: set[str] | None = None,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_csrf_protection",
    ):
        """
        CSRF (Cross-Site Request Forgery) protection middleware.
        Parameters:
        - app: ASGI application
        - secret: Secret key for token generation
        - cookie_name: CSRF token cookie name
        - cookie_path: Cookie path
        - cookie_domain: Cookie domain
        - cookie_secure: Use secure cookies
        - cookie_httponly: Use HTTP-only cookies (default False, presumably so
          client script can copy the cookie into `header_name`; confirm)
        - cookie_samesite: SameSite cookie attribute
        - header_name: HTTP header containing CSRF token
        - safe_methods: HTTP methods that don't require CSRF protection
          (None presumably selects a built-in default set; confirm)
        - exclude: Paths to exclude from CSRF protection
        - exclude_opt_key: Route option key for exclusion
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
# Middleware for handling exceptions and generating appropriate error responses.
class ExceptionHandlerMiddleware(AbstractMiddleware):
    """Turn exceptions raised downstream into error responses (API stub)."""
    def __init__(
        self,
        app: ASGIApp,
        *,
        debug: bool = False,
        exception_handlers: dict[type[Exception], ExceptionHandler] | None = None,
    ):
        """
        Exception handling middleware.
        Parameters:
        - app: ASGI application
        - debug: Enable debug mode with detailed error info
        - exception_handlers: Custom exception handlers, keyed by exception type
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
    async def handle_exception(
        self,
        request: Request,
        exc: Exception,
    ) -> Response:
        """
        Handle an exception and generate appropriate response.
        Parameters:
        - request: HTTP request that caused the exception
        - exc: Exception that was raised
        Returns:
        Response object for the exception
        NOTE(review): body elided in this listing; as written it returns None
        despite the Response annotation.
        """
# Middleware for implementing rate limiting and throttling.
class RateLimitMiddleware(AbstractMiddleware):
    """Store-backed rate-limiting middleware; only the constructor signature is shown."""
    def __init__(
        self,
        app: ASGIApp,
        store: Store,
        *,
        rate_limit: RateLimit | None = None,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str = "skip_rate_limiting",
    ):
        """
        Rate limiting middleware.
        Parameters:
        - app: ASGI application
        - store: Storage backend for rate limit data
        - rate_limit: Default rate limit configuration (behaviour when None is
          not shown here; confirm)
        - exclude: Paths to exclude from rate limiting
        - exclude_opt_key: Route option key for exclusion
        NOTE(review): body elided in this listing (no super().__init__ call shown).
        """
class RateLimit:
    """Rate-limit settings: at most `rate` requests per `per` window (API stub)."""
    def __init__(
        self,
        rate: int,
        per: timedelta,
        *,
        identifier: Callable[[Request], str] | None = None,
        key_formatter: Callable[[str, str], str] | None = None,
    ):
        """
        Rate limit configuration.
        Parameters:
        - rate: Number of requests allowed
        - per: Time window for rate limiting
        - identifier: Function to identify clients (same shape as RateLimitIdentifier);
          None presumably falls back to a default identity; confirm
        - key_formatter: Function to format cache keys (same shape as RateLimitKeyFormatter)
        NOTE(review): body elided in this listing; as written no attributes are stored.
        """
# Common patterns for implementing custom middleware.
class ProcessingTimeMiddleware(AbstractMiddleware):
    """Example middleware that adds an ``x-process-time`` response header.

    The header value is ``str(seconds)`` elapsed between the request entering
    this middleware and the downstream app emitting ``http.response.start``.
    Non-HTTP scopes are passed through untouched.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # perf_counter is monotonic; time.time() follows the wall clock and can
        # jump (NTP, DST fixes), producing negative or inflated durations.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                process_time = time.perf_counter() - start_time
                # ASGI permits any iterable of header pairs (e.g. a tuple), so
                # build a new list instead of appending to it in place.
                message["headers"] = [
                    *message.get("headers", ()),
                    [b"x-process-time", str(process_time).encode()],
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)
class LoggingMiddleware(AbstractMiddleware):
    """Example middleware that logs each HTTP request and its response status."""

    def __init__(self, app: ASGIApp, logger: Logger | None = None):
        """
        Parameters:
        - app: ASGI application to wrap
        - logger: Logger to write to; defaults to this module's logger
        """
        super().__init__(app)
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        request = Request(scope, receive)
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        self.logger.info("Request: %s %s", request.method, request.url)
        # Monotonic clock so durations are immune to wall-clock adjustments.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                self.logger.info("Response: %s (%.3fs)", message["status"], duration)
            await send(message)

        await self.app(scope, receive, send_wrapper)
# Protocol definition for middleware implementations.
class MiddlewareProtocol(Protocol):
    """Structural interface: constructed with an ASGI app, then called with (scope, receive, send)."""
    def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
        """Initialize middleware with ASGI app (plus any middleware-specific keyword options)."""
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable interface."""
from litestar import Litestar, get
from litestar.middleware import DefineMiddleware
from litestar.middleware.cors import CORSConfig
from litestar.middleware.compression import CompressionConfig
@get("/api/data")
def get_data() -> dict:
    """Return a static greeting payload."""
    greeting = "Hello World"
    return {"message": greeting}


# CORS limited to one trusted frontend; gzip only for bodies of at least 500.
app = Litestar(
    route_handlers=[get_data],
    cors_config=CORSConfig(
        allow_headers=["content-type"],
        allow_methods=["GET", "POST"],
        allow_origins=["https://frontend.example.com"],
    ),
    compression_config=CompressionConfig(
        minimum_size=500,
        backend="gzip",
    ),
)

import time
from litestar.middleware.base import AbstractMiddleware
class RequestTimingMiddleware(AbstractMiddleware):
    """Add an ``x-response-time`` header (seconds, 3 decimals) to HTTP responses."""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # Monotonic clock: time.time() can jump with wall-clock adjustments.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                # Headers may be any iterable (e.g. a tuple) per ASGI, so rebuild
                # the list rather than calling .append() on the original object.
                message["headers"] = [
                    *message.get("headers", ()),
                    [b"x-response-time", f"{duration:.3f}".encode()],
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)
# Register the timing middleware on the application.
app = Litestar(
    middleware=[RequestTimingMiddleware],
    route_handlers=[get_data],
)
from litestar.middleware.authentication import AbstractAuthenticationMiddleware
class APIKeyMiddleware(AbstractAuthenticationMiddleware):
    """Authenticate requests by matching the ``X-API-Key`` header against a key set."""

    def __init__(self, app: ASGIApp, api_keys: set[str]):
        """
        Parameters:
        - app: ASGI application to wrap
        - api_keys: Accepted API keys
        """
        super().__init__(app)
        self.api_keys = api_keys

    async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
        """
        Resolve the caller from the ``X-API-Key`` header.
        Returns:
        AuthenticationResult whose `user` and `auth` carry the matched key.
        Raises:
        NotAuthorizedException: if the header is missing or not an accepted key.
        """
        from litestar.exceptions import NotAuthorizedException

        api_key = connection.headers.get("X-API-Key")
        if api_key is None or api_key not in self.api_keys:
            # Authentication middleware signals failure by raising; an empty
            # AuthenticationResult() is not valid (its fields are required).
            raise NotAuthorizedException("Invalid or missing API key")
        return AuthenticationResult(user={"api_key": api_key}, auth=api_key)
# NOTE(review): a literal API key in source is for illustration only; load real
# keys from configuration or the environment.
@get(
    "/admin/users",
    middleware=[DefineMiddleware(APIKeyMiddleware, {"secret-key-123"})],
)
def admin_users() -> list[dict]:
    """Return the static admin user list; the route is wrapped by APIKeyMiddleware."""
    admin = {"id": 1, "name": "admin"}
    return [admin]


from litestar.stores.redis import RedisStore
from litestar.middleware.session import SessionMiddleware
# Redis session store. RedisStore's constructor takes a Redis client instance;
# with_client() builds one from a URL.
redis_store = RedisStore.with_client(url="redis://localhost:6379")
# SessionMiddleware.__init__ takes the wrapped ASGI app as its first positional
# argument, so it cannot be instantiated here. DefineMiddleware records the class
# and its remaining arguments for instantiation once the app is available.
# NOTE(review): the secret is a placeholder; load it from configuration/env.
session_middleware = DefineMiddleware(
    SessionMiddleware,
    store=redis_store,
    secret="session-secret-key",
    key="sessionid",
    max_age=3600,  # 1 hour
    httponly=True,
    secure=True,
    samesite="strict",
)
@get("/profile")
def get_profile(request: Request) -> dict:
    """
    Return the logged-in user's id plus a snapshot of their session.
    Raises:
    NotAuthorizedException: when the session holds no truthy "user_id".
    """
    session = request.session
    user_id = session.get("user_id")
    if user_id:
        return {"user_id": user_id, "session_data": dict(session)}
    raise NotAuthorizedException("Login required")
@post("/login")
def login(request: Request, data: dict) -> dict:
    """
    Log a user in by writing their identity into the session.
    Raises:
    NotAuthorizedException: for any username other than "alice".
    """
    # Authenticate user (simplified): only the username is checked.
    # NOTE(review): a real login should verify a password and rotate the session
    # identifier to prevent session fixation.
    if data.get("username") != "alice":
        raise NotAuthorizedException("Invalid credentials")
    request.session["user_id"] = 123
    request.session["username"] = "alice"
    return {"status": "logged in"}
# Session handling applies to both the profile and login routes.
app = Litestar(
    middleware=[session_middleware],
    route_handlers=[get_profile, login],
)
from litestar.middleware.cors import CORSConfig
# Explicit origin allow-list with credentials enabled; preflight answers may be
# cached for 24 hours.
cors_config = CORSConfig(
    allow_credentials=True,
    allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_origins=["https://app.example.com", "https://admin.example.com"],
    expose_headers=["X-Total-Count"],
    max_age=86400,  # 24 hours
)
app = Litestar(
    cors_config=cors_config,
    route_handlers=[...],  # placeholder: list the real route handlers here
)
from litestar import Dependency
from litestar.middleware.base import AbstractMiddleware
class DatabaseMiddleware(AbstractMiddleware):
    """Check out a pooled DB connection per HTTP request and expose it via the scope."""

    def __init__(self, app: ASGIApp, db_pool: Any):
        """
        Parameters:
        - app: ASGI application to wrap
        - db_pool: pool exposing an async-context-manager `acquire()`
        """
        super().__init__(app)
        self.db_pool = db_pool

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        # The connection stays checked out for the entire downstream call and is
        # handed back when the `async with` exits (pool semantics assumed).
        async with self.db_pool.acquire() as connection:
            scope["db_connection"] = connection
            await self.app(scope, receive, send)
def get_db_connection(scope: dict) -> Any:
    """Return the connection DatabaseMiddleware stored in the scope, or None if absent."""
    if "db_connection" in scope:
        return scope["db_connection"]
    return None
from litestar.di import Provide


# `dependencies` maps names to providers. `Dependency(...)` is a parameter marker,
# not a provider, so wrap the callable in Provide. get_db_connection does no
# blocking I/O, so it need not run in a worker thread.
@get("/users", dependencies={"db": Provide(get_db_connection, sync_to_thread=False)})
async def get_users(db: Any) -> list[dict]:
    """List all users via the connection injected from DatabaseMiddleware."""
    # Use database connection from middleware (asyncpg-style `fetch` assumed).
    result = await db.fetch("SELECT * FROM users")
    return [dict(row) for row in result]
from litestar.middleware.exceptions import ExceptionHandlerMiddleware
from litestar.exceptions import HTTPException
import logging
import json
import uuid

logger = logging.getLogger(__name__)


class CustomExceptionMiddleware(AbstractMiddleware):
    """Convert unhandled non-HTTP exceptions into a logged, opaque 500 response.

    HTTPException instances propagate unchanged. Any other exception is logged
    with a generated error_id, which is also returned to the client so the two
    can be correlated. If the connection is not HTTP, or a response has already
    started, the exception is re-raised after logging, because a fresh response
    cannot be sent.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        response_started = False

        async def send_wrapper(message: Message) -> None:
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except HTTPException:
            # Let HTTP exceptions pass through
            raise
        except Exception:
            error_id = str(uuid.uuid4())
            logger.exception("Unhandled exception occurred (error_id=%s)", error_id)
            if scope["type"] != "http" or response_started:
                raise
            # Same status_code/detail/extra fields as the HTTPException this
            # example originally built, sent directly over ASGI.
            body = json.dumps(
                {
                    "status_code": 500,
                    "detail": "Internal server error",
                    "extra": {"error_id": error_id},
                }
            ).encode()
            await send(
                {
                    "type": "http.response.start",
                    "status": 500,
                    "headers": [
                        [b"content-type", b"application/json"],
                        [b"content-length", str(len(body)).encode()],
                    ],
                }
            )
            await send({"type": "http.response.body", "body": body})
app = Litestar(
    route_handlers=[...],
    middleware=[CustomExceptionMiddleware],
    debug=False,
)


class ConditionalMiddleware(AbstractMiddleware):
    """Add an ``x-conditional-timing`` header only to requests accepted by `condition`."""

    def __init__(self, app: ASGIApp, condition: Callable[[Request], bool]):
        """
        Parameters:
        - app: ASGI application to wrap
        - condition: predicate deciding whether timing is applied to a request
        """
        super().__init__(app)
        self.condition = condition

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        request = Request(scope, receive)
        if not self.condition(request):
            await self.app(scope, receive, send)
            return
        # Apply middleware logic; monotonic clock avoids wall-clock jumps.
        start_time = time.perf_counter()

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                duration = time.perf_counter() - start_time
                # Rebuild rather than .append(): ASGI headers may be a tuple.
                message["headers"] = [
                    *message.get("headers", ()),
                    [b"x-conditional-timing", f"{duration:.3f}".encode()],
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)
# Only apply timing to API routes
def is_api_route(request: Request) -> bool:
    """Return True when the request path lies under the ``/api/`` prefix (bare ``/api`` excluded)."""
    path = request.url.path
    return path[:5] == "/api/"
# Timing middleware gated by is_api_route, configured via DefineMiddleware.
app = Litestar(
    middleware=[DefineMiddleware(ConditionalMiddleware, is_api_route)],
    route_handlers=[...],  # placeholder: list the real route handlers here
)
# ASGI types
# Each right-hand side is evaluated eagerly at import time, so every alias must
# be bound before another alias references it (ASGIApp depends on all others).
Message = dict[str, Any]
Scope = dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
# Middleware types
# Either a middleware class (constructed with the wrapped app, as in
# AbstractMiddleware.__init__) or a factory callable that returns an ASGI app.
# NOTE: the runtime `|` on these generic aliases requires Python 3.10+.
MiddlewareType = type[AbstractMiddleware] | Callable[..., ASGIApp]
# An entry of a `middleware=[...]` list: a bare middleware or a DefineMiddleware
# bundling one with its constructor arguments.
Middleware = MiddlewareType | DefineMiddleware
# Exception handler type
# Takes (request, exception) and returns a Response either directly or as an
# awaitable; this is the value type of ExceptionHandlerMiddleware's
# `exception_handlers` mapping.
ExceptionHandler = Callable[[Request, Exception], Response | Awaitable[Response]]
# Store interface for session/rate limiting
class Store(Protocol):
    """Structural interface for the async key/value backend taken by SessionMiddleware and RateLimitMiddleware."""

    async def get(self, key: str) -> Any | None:
        """Return the value stored under `key` (presumably None when absent; confirm)."""
        ...

    async def set(self, key: str, value: Any, expires_in: int | None = None) -> None:
        """Store `value` under `key`; `expires_in` is presumably seconds, None meaning no expiry (confirm)."""
        ...

    async def delete(self, key: str) -> None:
        """Remove `key` from the store."""
        ...
# Rate limiting types
# Maps a request to the client identity string that limits are counted against
# (same shape as RateLimit's `identifier` argument).
RateLimitIdentifier = Callable[[Request], str]
# Builds a cache key from two strings (same shape as RateLimit's `key_formatter`);
# which argument is which is not shown here. Confirm against the implementation.
RateLimitKeyFormatter = Callable[[str, str], str]
# Install with Tessl CLI
npx tessl i tessl/pypi-litestar