Middleware for Starlette that allows you to store and access the context data of a request.
Extensible plugin architecture for extracting data from requests and enriching responses. The plugin system provides a clean separation between data extraction logic and context management, with built-in plugins for common use cases.
Foundation classes for creating custom plugins that integrate with the middleware system.
class Plugin(metaclass=abc.ABCMeta):
    """
    Base class for plugins that extract a single value from a request.

    Each plugin should be responsible for extracting exactly one thing.

    key: name of the request header the value is read from; the same name
    is used as the context key for this plugin.
    """

    key: str  # Header key or context key for this plugin

    async def extract_value_from_header_by_key(
        self, request: Union[Request, HTTPConnection]
    ) -> Optional[Any]:
        """
        Extract the value of the request header named by the plugin's key.

        Parameters:
        - request: Request or HTTPConnection object

        Returns:
        Optional[Any]: Header value, or None if the header is not present
        """

    async def process_request(
        self, request: Union[Request, HTTPConnection]
    ) -> Optional[Any]:
        """
        Hook that runs for every request.

        By default it extracts the value from the header named by ``key``;
        override it to transform or validate that value.

        Parameters:
        - request: Request or HTTPConnection object

        Returns:
        Optional[Any]: Processed value to store in the context
        """

    async def enrich_response(self, arg: Union[Response, Message]) -> None:
        """
        Hook that runs for every response.

        Does nothing by default; override it to modify the response
        (for example, to add headers).

        Parameters:
        - arg: Response object (ContextMiddleware) or Message dict
          (RawContextMiddleware)
        """
class PluginUUIDBase(Plugin):
    """
    Base class for UUID-based plugins.

    Reads a UUID from the request header, optionally validates it, and
    generates a new one when required.
    """

    uuid_functions_mapper = {4: uuid.uuid4}  # Supported UUID versions

    def __init__(
        self,
        force_new_uuid: bool = False,
        version: int = 4,
        validate: bool = True,
        error_response: Optional[Response] = None
    ):
        """
        Initialize the UUID plugin.

        Parameters:
        - force_new_uuid: Always generate a new UUID, ignoring the request header
        - version: UUID version; only versions listed in
          ``uuid_functions_mapper`` are accepted (currently only 4)
        - validate: Validate the UUID format if one is present in the request
        - error_response: Custom response used for validation errors

        Raises:
        ConfigurationError: If an unsupported UUID version is specified
        """

    def validate_uuid(self, uuid_to_validate: str) -> None:
        """
        Validate the format of a UUID string.

        Parameters:
        - uuid_to_validate: UUID string to validate

        Raises:
        WrongUUIDError: If the UUID format is invalid
        """

    def get_new_uuid(self) -> str:
        """
        Generate a new UUID.

        Returns:
        str: New UUID as a hex string
        """

    async def extract_value_from_header_by_key(
        self, request: Union[Request, HTTPConnection]
    ) -> Optional[str]:
        """
        Return the UUID from the request header, or a newly generated one.

        A new UUID is generated when ``force_new_uuid`` is set or when the
        request does not carry one.

        Parameters:
        - request: Request or HTTPConnection object

        Returns:
        Optional[str]: UUID string

        Raises:
        WrongUUIDError: If validation is enabled and the UUID is invalid
        """

    async def enrich_response(self, arg: Any) -> None:
        """
        Add the UUID to the response headers.

        Parameters:
        - arg: Response object (ContextMiddleware) or Message dict
          (RawContextMiddleware)
        """
Ready-to-use plugins for common header extraction and processing scenarios.
class RequestIdPlugin(PluginUUIDBase):
    """Manages request IDs via the X-Request-ID header (UUID handling from PluginUUIDBase)."""

    key = HeaderKeys.request_id  # "X-Request-ID"


class CorrelationIdPlugin(PluginUUIDBase):
    """Manages correlation IDs via the X-Correlation-ID header (UUID handling from PluginUUIDBase)."""

    key = HeaderKeys.correlation_id  # "X-Correlation-ID"
class ApiKeyPlugin(Plugin):
    """Extracts the API key from the X-API-Key header (default Plugin behaviour)."""

    key = HeaderKeys.api_key  # "X-API-Key"


class UserAgentPlugin(Plugin):
    """Extracts the User-Agent header (default Plugin behaviour)."""

    key = HeaderKeys.user_agent  # "User-Agent"


class ForwardedForPlugin(Plugin):
    """Extracts the X-Forwarded-For header (default Plugin behaviour)."""

    key = HeaderKeys.forwarded_for  # "X-Forwarded-For"
class DateHeaderPlugin(Plugin):
    """Parses the Date request header (RFC1123 format) into a datetime."""

    key = HeaderKeys.date  # "Date"

    def __init__(
        self,
        *args: Any,
        error_response: Optional[Response] = Response(status_code=400)
    ) -> None:
        """
        Initialize the date header plugin.

        Parameters:
        - *args: Extra positional arguments (their handling is not
          documented here)
        - error_response: Response to return on date format errors

        NOTE(review): the default ``Response(status_code=400)`` is evaluated
        once, when this function is defined. Every instance created without
        an explicit ``error_response`` therefore shares the same Response
        object. Pass a fresh one per plugin if it might be mutated.
        """

    @staticmethod
    def rfc1123_to_dt(s: str) -> datetime.datetime:
        """
        Convert an RFC1123 date string to a datetime.

        Parameters:
        - s: RFC1123 formatted date string,
          e.g. 'Wed, 01 Jan 2020 04:27:12 GMT'

        Returns:
        datetime.datetime: Parsed datetime object (whether it is
        timezone-aware is not documented here; confirm before comparing it
        with aware datetimes)

        Raises:
        ValueError: If the date format is invalid
        """

    async def process_request(
        self, request: Union[Request, HTTPConnection]
    ) -> Optional[datetime.datetime]:
        """
        Parse the Date header into a datetime.

        Parameters:
        - request: Request or HTTPConnection object

        Returns:
        Optional[datetime.datetime]: Parsed date, or None if the header is
        not present

        Raises:
        DateFormatError: If the date format is invalid
        """
from starlette_context.middleware import ContextMiddleware
from starlette_context.plugins import RequestIdPlugin, UserAgentPlugin
from starlette_context import context
from starlette.responses import JSONResponse

# Setup middleware with plugins.
# `app` is assumed to be an existing Starlette (or FastAPI) application.
app.add_middleware(
    ContextMiddleware,
    plugins=[
        RequestIdPlugin(),
        UserAgentPlugin(),
    ],
)


# Access plugin data in handlers
async def my_handler(request):
    """Return the values the plugins stored in the request context.

    Plain Starlette endpoints must return a ``Response`` instance. A bare
    ``dict`` is only accepted by frameworks such as FastAPI, so the payload
    is wrapped in ``JSONResponse`` to work with both.
    """
    request_id = context["X-Request-ID"]  # From RequestIdPlugin
    user_agent = context["User-Agent"]  # From UserAgentPlugin
    return JSONResponse({"request_id": request_id, "user_agent": user_agent})

from starlette_context.plugins import RequestIdPlugin, CorrelationIdPlugin
from starlette.responses import JSONResponse
# ContextMiddleware is used below, so it must be imported in this snippet too.
from starlette_context.middleware import ContextMiddleware

# Always generate new request ID
request_id_plugin = RequestIdPlugin(force_new_uuid=True)

# Use the existing correlation ID or generate a new one, with validation.
# A malformed incoming ID is rejected with this custom response.
correlation_plugin = CorrelationIdPlugin(
    validate=True,
    error_response=JSONResponse(
        {"error": "Invalid correlation ID format"},
        status_code=400,
    ),
)

app.add_middleware(
    ContextMiddleware,
    plugins=[request_id_plugin, correlation_plugin],
)

from starlette_context.plugins import DateHeaderPlugin
from starlette_context import context
import datetime
from starlette.responses import JSONResponse
# ContextMiddleware is used below, so it must be imported in this snippet too.
from starlette_context.middleware import ContextMiddleware

# Parse RFC1123 date headers
date_plugin = DateHeaderPlugin()
app.add_middleware(ContextMiddleware, plugins=[date_plugin])


async def handler(request):
    """Return the parsed ``Date`` header formatted as ``YYYY-MM-DD HH:MM:SS``.

    Returns ``{"parsed_date": None}`` when the request carried no Date header.
    """
    date_value = context.get("Date")  # datetime.datetime object or None
    if isinstance(date_value, datetime.datetime):
        formatted_date = date_value.strftime("%Y-%m-%d %H:%M:%S")
        return JSONResponse({"parsed_date": formatted_date})
    # Plain Starlette endpoints must return a Response, not a dict.
    return JSONResponse({"parsed_date": None})

from starlette_context.plugins import Plugin
from starlette_context import context
# ContextMiddleware is used below, so it must be imported in this snippet too.
from starlette_context.middleware import ContextMiddleware
import json


class CustomHeaderPlugin(Plugin):
    """Parse the JSON payload of ``X-Custom-Data`` and echo it in the response."""

    key = "X-Custom-Data"

    async def process_request(self, request):
        """Return the decoded JSON header value, an error dict, or None."""
        # Extract and process header
        raw_value = await self.extract_value_from_header_by_key(request)
        if raw_value:
            try:
                # Parse JSON data
                return json.loads(raw_value)
            except json.JSONDecodeError:
                return {"error": "Invalid JSON in header"}
        return None

    async def enrich_response(self, response):
        """Copy the processed value into the ``X-Processed-Data`` header."""
        custom_data = context.get(self.key)
        # Compare against None so that falsy but valid JSON payloads
        # (0, false, [], {}, "") are echoed as well.
        # Only Response objects (ContextMiddleware) expose ``headers``;
        # Message dicts (RawContextMiddleware) are skipped.
        if custom_data is not None and hasattr(response, 'headers'):
            response.headers["X-Processed-Data"] = json.dumps(custom_data)


# Use custom plugin
app.add_middleware(
    ContextMiddleware,
    plugins=[CustomHeaderPlugin()],
)

from starlette_context.plugins import PluginUUIDBase
from starlette_context.header_keys import HeaderKeys
# `context` and ContextMiddleware are used below and must be imported here.
from starlette_context import context
from starlette_context.middleware import ContextMiddleware


class TraceIdPlugin(PluginUUIDBase):
    """UUID plugin for ``X-Trace-ID`` that also emits ``X-Trace-Context``."""

    key = "X-Trace-ID"

    def __init__(self, **kwargs):
        # Always validate trace IDs, generate if missing. Override any
        # caller-supplied values rather than passing them twice, which
        # would raise "got multiple values for keyword argument".
        kwargs["force_new_uuid"] = False
        kwargs["validate"] = True
        super().__init__(**kwargs)

    async def enrich_response(self, response):
        """Add the trace ID header, plus an ``X-Trace-Context`` header."""
        # Always add trace ID to response
        await super().enrich_response(response)
        # Add to custom header as well (Response objects only; Message
        # dicts from RawContextMiddleware have no ``headers`` attribute).
        trace_id = context[self.key]
        if hasattr(response, 'headers'):
            response.headers["X-Trace-Context"] = f"trace-id={trace_id}"


app.add_middleware(
    ContextMiddleware,
    plugins=[TraceIdPlugin()],
)

from starlette_context.plugins import DateHeaderPlugin
from starlette_context.errors import DateFormatError  # raised when the Date format is invalid
from starlette.responses import JSONResponse
# ContextMiddleware is used below, so it must be imported in this snippet too.
from starlette_context.middleware import ContextMiddleware

# Custom error response for date parsing
error_response = JSONResponse(
    {
        "error": "Invalid date format",
        "expected": "RFC1123 format (e.g., 'Wed, 01 Jan 2020 04:27:12 GMT')",
    },
    status_code=422,
)

date_plugin = DateHeaderPlugin(error_response=error_response)
app.add_middleware(ContextMiddleware, plugins=[date_plugin])

from starlette_context.plugins import Plugin
from starlette.responses import JSONResponse
# `context` and ContextMiddleware are used below and must be imported here.
from starlette_context import context
from starlette_context.middleware import ContextMiddleware


class MultiHeaderPlugin(Plugin):
    """Collect several request headers into a single context entry."""

    key = "combined_headers"

    def __init__(self, header_keys):
        super().__init__()
        self.header_keys = header_keys

    async def process_request(self, request):
        """Return {header: value} for present, non-empty headers, or None."""
        headers = {}
        for header_key in self.header_keys:
            value = request.headers.get(header_key)
            if value:
                headers[header_key] = value
        return headers if headers else None


# Extract multiple headers into single context entry
multi_plugin = MultiHeaderPlugin([
    "X-Forwarded-For",
    "X-Real-IP",
    "X-Client-ID",
])
app.add_middleware(ContextMiddleware, plugins=[multi_plugin])


# Access in handler
async def handler(request):
    """Return the best-effort client IP derived from forwarding headers."""
    # process_request stores None when no header matched, so a default
    # passed to context.get() would not apply; fall back explicitly.
    headers = context.get("combined_headers") or {}
    forwarded_for = headers.get("X-Forwarded-For")
    # X-Forwarded-For may hold a chain "client, proxy1, proxy2"; the first
    # entry is the originating client. These headers are client-controlled,
    # so only trust them behind a proxy that sets or sanitises them.
    client_ip = (
        (forwarded_for.split(",")[0].strip() if forwarded_for else None)
        or headers.get("X-Real-IP")
        or "unknown"
    )
    return JSONResponse({"client_ip": client_ip})
process_request
# Simple header extraction
# Names used by the patterns below.
from starlette_context import context
from starlette_context.plugins import Plugin


class SimplePlugin(Plugin):
    key = "X-My-Header"
    # Uses the default implementation: the header value is extracted as-is.


# Header processing
class ProcessingPlugin(Plugin):
    key = "X-Complex-Header"

    async def process_request(self, request):
        raw_value = await self.extract_value_from_header_by_key(request)
        return self.process_value(raw_value)

    def process_value(self, value):
        # Custom processing logic goes here. Returning the value unchanged,
        # rather than an implicit None, keeps the header usable until then.
        return value


# Response enrichment
class EnrichingPlugin(Plugin):
    key = "X-Data"

    async def enrich_response(self, response):
        data = context.get(self.key)
        # Compare against None so falsy values (0, "", ...) are still emitted.
        # Message dicts (RawContextMiddleware) have no ``headers``.
        if data is not None and hasattr(response, 'headers'):
            response.headers["X-Processed"] = str(data)

import pytest
from starlette.requests import Request
from starlette_context import request_cycle_context
async def test_custom_plugin():
plugin = CustomHeaderPlugin()
# Mock request with header
scope = {
"type": "http",
"headers": [(b"x-custom-data", b'{"key": "value"}')]
}
request = Request(scope)
# Test plugin processing
result = await plugin.process_request(request)
assert result == {"key": "value"}
# Test with context
with request_cycle_context({"X-Custom-Data": result}):
# Test response enrichment
response = Response()
await plugin.enrich_response(response)
assert "X-Processed-Data" in response.headersInstall with Tessl CLI
npx tessl i tessl/pypi-starlette-context