A library for creating GraphQL APIs using dataclasses and type annotations with extensive framework integration support.
Schema and field-level extensions for adding custom functionality like validation, caching, security, and monitoring to GraphQL operations. The extension system provides hooks into the GraphQL execution lifecycle for custom logic implementation.
Base class for creating custom schema-level extensions.
class SchemaExtension:
"""Base class for schema-level extensions."""
def on_request_start(self) -> None:
"""Called when a GraphQL request starts."""
def on_request_end(self) -> None:
"""Called when a GraphQL request ends."""
def on_validation_start(self) -> None:
"""Called when query validation starts."""
def on_validation_end(self) -> None:
"""Called when query validation ends."""
def on_parsing_start(self) -> None:
"""Called when query parsing starts."""
def on_parsing_end(self) -> None:
"""Called when query parsing ends."""
def on_executing_start(self) -> None:
"""Called when query execution starts."""
def on_executing_end(self) -> None:
"""Called when query execution ends."""
def get_results(self) -> Dict[str, Any]:
"""Return extension results to include in response."""

Usage Example:
class TimingExtension(strawberry.extensions.SchemaExtension):
    """Measure how long parsing, validation and execution take.

    Every ``on_*_start`` hook stores a timestamp and the matching
    ``on_*_end`` hook converts it into a duration in seconds.
    ``get_results`` returns those durations under the ``"timing"`` key.
    """

    def __init__(self):
        # Wall-clock time at which the request started.
        self.start_time = None
        # Start markers are declared up front so that an ``*_end`` hook firing
        # without its ``*_start`` counterpart cannot raise AttributeError.
        self.parsing_start = None
        self.validation_start = None
        self.execution_start = None
        # Durations in seconds; None means the phase has not completed.
        self.parsing_time = None
        self.validation_time = None
        self.execution_time = None

    @staticmethod
    def _elapsed_since(started):
        """Return the seconds elapsed since ``started``, or None if unset."""
        if started is None:
            return None
        return time.perf_counter() - started

    @staticmethod
    def _format_duration(duration):
        """Render ``duration`` as e.g. ``"0.1234s"``; None stays None."""
        if duration is None:
            return None
        return f"{duration:.4f}s"

    def on_request_start(self):
        self.start_time = time.time()
        # Forget the previous request's measurements so that a request which
        # stops early does not report stale numbers.
        self.parsing_start = self.parsing_time = None
        self.validation_start = self.validation_time = None
        self.execution_start = self.execution_time = None

    def on_parsing_start(self):
        # perf_counter is monotonic, so a system clock adjustment cannot
        # produce a negative or inflated duration.
        self.parsing_start = time.perf_counter()

    def on_parsing_end(self):
        self.parsing_time = self._elapsed_since(self.parsing_start)

    def on_validation_start(self):
        self.validation_start = time.perf_counter()

    def on_validation_end(self):
        self.validation_time = self._elapsed_since(self.validation_start)

    def on_executing_start(self):
        self.execution_start = time.perf_counter()

    def on_executing_end(self):
        self.execution_time = self._elapsed_since(self.execution_start)

    def get_results(self):
        """Return the phase durations to include in the response.

        A phase that did not run is reported as ``None`` instead of raising
        ``TypeError`` while formatting.
        """
        return {
            "timing": {
                "parsing": self._format_duration(self.parsing_time),
                "validation": self._format_duration(self.validation_time),
                "execution": self._format_duration(self.execution_time),
            }
        }
# Use extension
schema = strawberry.Schema(
query=Query,
extensions=[TimingExtension()]
)

Limits the depth of GraphQL queries to prevent deeply nested queries.
class QueryDepthLimiter(SchemaExtension):
    """Limits the maximum depth of GraphQL queries."""

    # NOTE(review): upstream strawberry appears to expose a ``should_ignore``
    # callable rather than an ``ignore`` list -- confirm this signature
    # against the installed version.
    def __init__(
        self,
        max_depth: int,
        callback: Optional[Callable] = None,
        # Quoted because IgnoreContext is declared after this class.
        ignore: Optional[List["IgnoreContext"]] = None,
    ):
        """
        Initialize query depth limiter.

        Args:
            max_depth: Maximum allowed query depth
            callback: Optional function called when depth exceeded
            ignore: Optional list of fields/types to ignore in depth
                calculation
        """
class IgnoreContext:
"""Context for ignoring specific fields in depth calculation."""
def __init__(
self,
field_name: str = None,
type_name: str = None
): ...

Usage Example:
from strawberry.extensions import QueryDepthLimiter, IgnoreContext

# Cap query depth at 10; the listed fields are left out of the depth
# calculation.
schema = strawberry.Schema(
    query=Query,
    extensions=[
        QueryDepthLimiter(
            max_depth=10,
            ignore=[
                IgnoreContext(field_name="introspection"),
                IgnoreContext(type_name="User", field_name="friends")
            ]
        )
    ]
)
# This query would be rejected if depth > 10
"""
query DeepQuery {
user {
posts {
comments {
author {
posts {
comments {
# ... continues deeply
}
}
}
}
}
}
}
"""

Caches GraphQL query validation results for improved performance.
class ValidationCache(SchemaExtension):
"""Caches GraphQL query validation results."""
def __init__(
self,
maxsize: int = 128,
ttl: int = None
):
"""
Initialize validation cache.
Args:
maxsize: Maximum number of cached validations
ttl: Time-to-live for cache entries in seconds
"""

Usage Example:
from strawberry.extensions import ValidationCache
schema = strawberry.Schema(
query=Query,
extensions=[
ValidationCache(maxsize=256, ttl=3600) # Cache for 1 hour
]
)

Caches GraphQL query parsing results.
class ParserCache(SchemaExtension):
"""Caches GraphQL query parsing results."""
def __init__(
self,
maxsize: int = 128,
ttl: int = None
):
"""
Initialize parser cache.
Args:
maxsize: Maximum number of cached parse results
ttl: Time-to-live for cache entries in seconds
"""

Disables GraphQL schema introspection for security.
class DisableIntrospection(SchemaExtension):
"""Disables GraphQL schema introspection queries."""
def __init__(self): ...

Usage Example:
from strawberry.extensions import DisableIntrospection
# Production schema without introspection
production_schema = strawberry.Schema(
query=Query,
extensions=[DisableIntrospection()]
)

Disables GraphQL query validation (use with caution).
class DisableValidation(SchemaExtension):
"""Disables GraphQL query validation."""
def __init__(self): ...

Masks detailed error information in production environments.
class MaskErrors(SchemaExtension):
"""Masks error details in GraphQL responses."""
def __init__(
self,
should_mask_error: Callable = None,
error_message: str = "Unexpected error."
):
"""
Initialize error masking.
Args:
should_mask_error: Function to determine if error should be masked
error_message: Generic error message to show
"""

Usage Example:
from strawberry.extensions import MaskErrors
def should_mask_error(error):
    """Tell ``MaskErrors`` whether ``error`` should be hidden.

    Returns False when the underlying exception is a ``ValidationError``
    and True for every other error.
    """
    caused_by_validation = isinstance(error.original_error, ValidationError)
    return not caused_by_validation
schema = strawberry.Schema(
query=Query,
extensions=[
MaskErrors(
should_mask_error=should_mask_error,
error_message="An error occurred while processing your request."
)
]
)

Limits the number of aliases or tokens allowed in a GraphQL query.
class MaxAliasesLimiter(SchemaExtension):
    """Limits the number of aliases in a GraphQL query."""

    def __init__(self, max_alias_count: int):
        """
        Initialize alias limiter.

        Args:
            max_alias_count: Maximum number of aliases allowed
        """
class MaxTokensLimiter(SchemaExtension):
"""Limits the number of tokens in a GraphQL query."""
def __init__(self, max_token_count: int):
"""
Initialize token limiter.
Args:
max_token_count: Maximum number of tokens allowed
"""

Usage Example:
from strawberry.extensions import MaxAliasesLimiter, MaxTokensLimiter
schema = strawberry.Schema(
query=Query,
extensions=[
MaxAliasesLimiter(max_alias_count=15),
MaxTokensLimiter(max_token_count=1000)
]
)

Adds custom GraphQL validation rules to the schema.
class AddValidationRules(SchemaExtension):
"""Adds custom GraphQL validation rules."""
def __init__(self, validation_rules: List[Callable]):
"""
Initialize validation rules extension.
Args:
validation_rules: List of validation rule functions
"""

Usage Example:
from strawberry.extensions import AddValidationRules
from graphql import GraphQLError
from graphql.validation import ValidationRule


class CustomValidationRule(ValidationRule):
    """Report an error for ``admin_``-prefixed fields requested by non-admins."""

    def enter_field(self, node, *_):
        # Custom validation logic
        # NOTE(review): graphql-core's ValidationContext does not appear to
        # define ``user``; this presumably relies on a customised context --
        # confirm before copying.
        if node.name.value.startswith("admin_") and not self.context.user.is_admin:
            # report_error expects a GraphQLError, not a bare string; passing
            # the node lets the error carry the field's source location.
            self.report_error(
                GraphQLError("Admin fields require admin access", nodes=node)
            )
schema = strawberry.Schema(
query=Query,
extensions=[
AddValidationRules([CustomValidationRule])
]
)

Base class for creating field-level extensions.
class FieldExtension:
"""Base class for field-level extensions."""
def apply(self, field: StrawberryField) -> StrawberryField:
"""
Apply extension logic to a field.
Args:
field: Field to modify
Returns:
Modified field
"""
def resolve(
self,
next_: Callable,
source: Any,
info: Info,
**kwargs
) -> Any:
"""
Custom field resolution logic.
Args:
next_: Next resolver in the chain
source: Parent object
info: GraphQL execution info
**kwargs: Field arguments
Returns:
Field resolution result
"""

Usage Example:
class CacheFieldExtension(strawberry.extensions.FieldExtension):
    """Field extension that caches resolver results.

    ``ttl`` is stored on the instance and passed to ``set_cache`` with every
    cached value.
    """

    def __init__(self, ttl: int = 300):
        self.ttl = ttl

    def resolve(self, next_, source, info, **kwargs):
        """Return the cached value for this field and arguments, else resolve.

        Args:
            next_: Next resolver in the chain
            source: Parent object
            info: GraphQL execution info
            **kwargs: Field arguments

        Returns:
            The cached or freshly resolved field value
        """
        # Build the key from the field path and the arguments in sorted order.
        # hash(str(kwargs)) was avoided on purpose: str hashes are salted per
        # process, so keys would not match between workers sharing a cache,
        # and str(kwargs) changes with the order the arguments were written.
        cache_key = f"{info.path}:{sorted(kwargs.items())!r}"

        # Try to get from cache.
        # NOTE(review): assuming get_from_cache returns None on a miss, a
        # cached None cannot be told apart from a miss and is recomputed.
        cached_result = get_from_cache(cache_key)
        if cached_result is not None:
            return cached_result

        # Execute resolver
        result = next_(source, info, **kwargs)

        # Cache result
        set_cache(cache_key, result, ttl=self.ttl)
        return result
@strawberry.type
class Query:
@strawberry.field(extensions=[CacheFieldExtension(ttl=600)])
def expensive_computation(self, value: int) -> int:
# Expensive computation that benefits from caching
return perform_heavy_calculation(value)

Provides input mutation pattern support for mutations.
class InputMutationExtension(FieldExtension):
"""Extension for input mutation pattern support."""
def apply(self, field: StrawberryField) -> StrawberryField: ...

Usage Example:
from strawberry.field_extensions import InputMutationExtension
# Input object taken by the create_user mutation.
@strawberry.input
class CreateUserInput:
    name: str
    email: str
# Result object returned by the create_user mutation.
@strawberry.type
class CreateUserPayload:
    user: User
    success: bool
@strawberry.type
class Mutation:
@strawberry.mutation(extensions=[InputMutationExtension()])
def create_user(self, input: CreateUserInput) -> CreateUserPayload:
user = User(name=input.name, email=input.email)
save_user(user)
return CreateUserPayload(user=user, success=True)

import logging
class LoggingExtension(strawberry.extensions.SchemaExtension):
def __init__(self):
self.logger = logging.getLogger("graphql")
def on_request_start(self):
self.logger.info("GraphQL request started")
def on_request_end(self):
self.logger.info("GraphQL request completed")
def on_validation_start(self):
self.logger.debug("Query validation started")
def on_validation_end(self):
self.logger.debug("Query validation completed")

class MetricsExtension(strawberry.extensions.SchemaExtension):
def __init__(self, metrics_client):
self.metrics = metrics_client
self.request_start_time = None
def on_request_start(self):
self.request_start_time = time.time()
self.metrics.increment("graphql.requests.started")
def on_request_end(self):
duration = time.time() - self.request_start_time
self.metrics.histogram("graphql.request.duration", duration)
self.metrics.increment("graphql.requests.completed")
def on_validation_end(self):
self.metrics.increment("graphql.validation.completed")

class AuthenticationExtension(strawberry.extensions.SchemaExtension):
def on_request_start(self):
# Validate authentication token
token = self.execution_context.context_value.get("auth_token")
if token:
user = validate_token(token)
self.execution_context.context_value["user"] = user
else:
self.execution_context.context_value["user"] = None

class AuditFieldExtension(strawberry.extensions.FieldExtension):
def resolve(self, next_, source, info, **kwargs):
# Log field access
user = info.context.user
field_name = info.field_name
audit_log.info(
f"User {user.id if user else 'anonymous'} accessed field {field_name}",
extra={
"user_id": user.id if user else None,
"field_name": field_name,
"query_path": info.path
}
)
# Execute resolver
result = next_(source, info, **kwargs)
return result
@strawberry.type
class User:
id: strawberry.ID
name: str
@strawberry.field(extensions=[AuditFieldExtension()])
def sensitive_data(self) -> str:
return self._sensitive_information

# Combine multiple extensions
schema = strawberry.Schema(
query=Query,
extensions=[
TimingExtension(),
LoggingExtension(),
QueryDepthLimiter(max_depth=15),
ValidationCache(maxsize=500),
ParserCache(maxsize=100),
DisableIntrospection(), # For production
MaskErrors() # For production
]
)

import os
def create_schema():
extensions = [
QueryDepthLimiter(max_depth=20),
ValidationCache(),
ParserCache()
]
# Add production-only extensions
if os.getenv("ENVIRONMENT") == "production":
extensions.extend([
DisableIntrospection(),
MaskErrors(),
MaxTokensLimiter(max_token_count=5000)
])
# Add development-only extensions
if os.getenv("ENVIRONMENT") == "development":
extensions.extend([
TimingExtension(),
LoggingExtension()
])
return strawberry.Schema(
query=Query,
extensions=extensions
)

Install with Tessl CLI:
npx tessl i tessl/pypi-strawberry-graphql