0
# Extensions System
1
2
Schema-level and field-level extensions add custom functionality — such as validation, caching, security, and monitoring — to GraphQL operations. The extension system provides hooks into the GraphQL execution lifecycle (request, parsing, validation, and execution) where custom logic can run.
3
4
## Capabilities
5
6
### Schema Extension Base Class
7
8
Base class for creating custom schema-level extensions.
9
10
```python { .api }
11
class SchemaExtension:
12
"""Base class for schema-level extensions."""
13
14
def on_request_start(self) -> None:
15
"""Called when a GraphQL request starts."""
16
17
def on_request_end(self) -> None:
18
"""Called when a GraphQL request ends."""
19
20
def on_validation_start(self) -> None:
21
"""Called when query validation starts."""
22
23
def on_validation_end(self) -> None:
24
"""Called when query validation ends."""
25
26
def on_parsing_start(self) -> None:
27
"""Called when query parsing starts."""
28
29
def on_parsing_end(self) -> None:
30
"""Called when query parsing ends."""
31
32
def on_executing_start(self) -> None:
33
"""Called when query execution starts."""
34
35
def on_executing_end(self) -> None:
36
"""Called when query execution ends."""
37
38
def get_results(self) -> Dict[str, Any]:
39
"""Return extension results to include in response."""
40
```
41
42
**Usage Example:**
43
44
```python
45
class TimingExtension(strawberry.extensions.SchemaExtension):
46
def __init__(self):
47
self.start_time = None
48
self.parsing_time = None
49
self.validation_time = None
50
self.execution_time = None
51
52
def on_request_start(self):
53
self.start_time = time.time()
54
55
def on_parsing_start(self):
56
self.parsing_start = time.time()
57
58
def on_parsing_end(self):
59
self.parsing_time = time.time() - self.parsing_start
60
61
def on_validation_start(self):
62
self.validation_start = time.time()
63
64
def on_validation_end(self):
65
self.validation_time = time.time() - self.validation_start
66
67
def on_executing_start(self):
68
self.execution_start = time.time()
69
70
def on_executing_end(self):
71
self.execution_time = time.time() - self.execution_start
72
73
def get_results(self):
74
return {
75
"timing": {
76
"parsing": f"{self.parsing_time:.4f}s",
77
"validation": f"{self.validation_time:.4f}s",
78
"execution": f"{self.execution_time:.4f}s"
79
}
80
}
81
82
# Use extension
83
schema = strawberry.Schema(
84
query=Query,
85
extensions=[TimingExtension()]
86
)
87
```
88
89
### Built-in Schema Extensions
90
91
#### Query Depth Limiter
92
93
Rejects GraphQL queries whose nesting depth exceeds a configured maximum, protecting the server from excessively deep queries.
94
95
```python { .api }
96
class QueryDepthLimiter(SchemaExtension):
97
"""Limits the maximum depth of GraphQL queries."""
98
99
def __init__(
100
self,
101
max_depth: int,
102
callback: Callable = None,
103
ignore: List[IgnoreContext] = None
104
):
105
"""
106
Initialize query depth limiter.
107
108
Args:
109
max_depth: Maximum allowed query depth
110
callback: Function called when depth exceeded
111
ignore: List of fields/types to ignore in depth calculation
112
"""
113
114
class IgnoreContext:
115
"""Context for ignoring specific fields in depth calculation."""
116
117
def __init__(
118
self,
119
field_name: str = None,
120
type_name: str = None
121
): ...
122
```
123
124
**Usage Example:**
125
126
```python
127
from strawberry.extensions import QueryDepthLimiter, IgnoreContext
128
129
schema = strawberry.Schema(
130
query=Query,
131
extensions=[
132
QueryDepthLimiter(
133
max_depth=10,
134
ignore=[
135
IgnoreContext(field_name="introspection"),
136
IgnoreContext(type_name="User", field_name="friends")
137
]
138
)
139
]
140
)
141
142
# This query would be rejected if depth > 10
143
"""
144
query DeepQuery {
145
user {
146
posts {
147
comments {
148
author {
149
posts {
150
comments {
151
# ... continues deeply
152
}
153
}
154
}
155
}
156
}
157
}
158
}
159
"""
160
```
161
162
#### Validation Cache
163
164
Caches GraphQL query validation results for improved performance.
165
166
```python { .api }
167
class ValidationCache(SchemaExtension):
168
"""Caches GraphQL query validation results."""
169
170
def __init__(
171
self,
172
maxsize: int = 128,
173
ttl: int = None
174
):
175
"""
176
Initialize validation cache.
177
178
Args:
179
maxsize: Maximum number of cached validations
180
ttl: Time-to-live for cache entries in seconds
181
"""
182
```
183
184
**Usage Example:**
185
186
```python
187
from strawberry.extensions import ValidationCache
188
189
schema = strawberry.Schema(
190
query=Query,
191
extensions=[
192
ValidationCache(maxsize=256, ttl=3600) # Cache for 1 hour
193
]
194
)
195
```
196
197
#### Parser Cache
198
199
Caches GraphQL query parsing results so that repeated queries do not have to be parsed again.
200
201
```python { .api }
202
class ParserCache(SchemaExtension):
203
"""Caches GraphQL query parsing results."""
204
205
def __init__(
206
self,
207
maxsize: int = 128,
208
ttl: int = None
209
):
210
"""
211
Initialize parser cache.
212
213
Args:
214
maxsize: Maximum number of cached parse results
215
ttl: Time-to-live for cache entries in seconds
216
"""
217
```
218
219
#### Disable Introspection
220
221
Disables GraphQL schema introspection for security.
222
223
```python { .api }
224
class DisableIntrospection(SchemaExtension):
225
"""Disables GraphQL schema introspection queries."""
226
227
def __init__(self): ...
228
```
229
230
**Usage Example:**
231
232
```python
233
from strawberry.extensions import DisableIntrospection
234
235
# Production schema without introspection
236
production_schema = strawberry.Schema(
237
query=Query,
238
extensions=[DisableIntrospection()]
239
)
240
```
241
242
#### Disable Validation
243
244
Disables GraphQL query validation, so queries are executed without first being validated (use with caution).
245
246
```python { .api }
247
class DisableValidation(SchemaExtension):
248
"""Disables GraphQL query validation."""
249
250
def __init__(self): ...
251
```
252
253
#### Mask Errors
254
255
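Replaces detailed error information in GraphQL responses with a generic message so that internal details are not exposed to clients; typically enabled in production environments.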
256
257
```python { .api }
258
class MaskErrors(SchemaExtension):
259
"""Masks error details in GraphQL responses."""
260
261
def __init__(
262
self,
263
should_mask_error: Callable = None,
264
error_message: str = "Unexpected error."
265
):
266
"""
267
Initialize error masking.
268
269
Args:
270
should_mask_error: Function to determine if error should be masked
271
error_message: Generic error message to show
272
"""
273
```
274
275
**Usage Example:**
276
277
```python
278
from strawberry.extensions import MaskErrors
279
280
def should_mask_error(error):
281
# Don't mask validation errors, mask internal errors
282
return not isinstance(error.original_error, ValidationError)
283
284
schema = strawberry.Schema(
285
query=Query,
286
extensions=[
287
MaskErrors(
288
should_mask_error=should_mask_error,
289
error_message="An error occurred while processing your request."
290
)
291
]
292
)
293
```
294
295
#### Rate Limiting Extensions
296
297
Limit the number of aliases or tokens that a single query may contain, bounding query size and complexity.
298
299
```python { .api }
300
class MaxAliasesLimiter(SchemaExtension):
301
"""Limits the number of aliases in a GraphQL query."""
302
303
def __init__(self, max_alias_count: int):
304
"""
305
Initialize alias limiter.
306
307
Args:
308
max_alias_count: Maximum number of aliases allowed
309
"""
310
311
class MaxTokensLimiter(SchemaExtension):
312
"""Limits the number of tokens in a GraphQL query."""
313
314
def __init__(self, max_token_count: int):
315
"""
316
Initialize token limiter.
317
318
Args:
319
max_token_count: Maximum number of tokens allowed
320
"""
321
```
322
323
**Usage Example:**
324
325
```python
326
from strawberry.extensions import MaxAliasesLimiter, MaxTokensLimiter
327
328
schema = strawberry.Schema(
329
query=Query,
330
extensions=[
331
MaxAliasesLimiter(max_alias_count=15),
332
MaxTokensLimiter(max_token_count=1000)
333
]
334
)
335
```
336
337
#### Add Validation Rules
338
339
Adds custom GraphQL validation rules to the schema.
340
341
```python { .api }
342
class AddValidationRules(SchemaExtension):
343
"""Adds custom GraphQL validation rules."""
344
345
def __init__(self, validation_rules: List[Callable]):
346
"""
347
Initialize validation rules extension.
348
349
Args:
350
validation_rules: List of validation rule functions
351
"""
352
```
353
354
**Usage Example:**
355
356
```python
357
from strawberry.extensions import AddValidationRules
358
from graphql.validation import ValidationRule
359
360
class CustomValidationRule(ValidationRule):
361
def enter_field(self, node, *_):
362
# Custom validation logic
363
if node.name.value.startswith("admin_") and not self.context.user.is_admin:
364
self.report_error("Admin fields require admin access")
365
366
schema = strawberry.Schema(
367
query=Query,
368
extensions=[
369
AddValidationRules([CustomValidationRule])
370
]
371
)
372
```
373
374
### Field Extension Base Class
375
376
Base class for creating field-level extensions.
377
378
```python { .api }
379
class FieldExtension:
380
"""Base class for field-level extensions."""
381
382
def apply(self, field: StrawberryField) -> StrawberryField:
383
"""
384
Apply extension logic to a field.
385
386
Args:
387
field: Field to modify
388
389
Returns:
390
Modified field
391
"""
392
393
def resolve(
394
self,
395
next_: Callable,
396
source: Any,
397
info: Info,
398
**kwargs
399
) -> Any:
400
"""
401
Custom field resolution logic.
402
403
Args:
404
next_: Next resolver in the chain
405
source: Parent object
406
info: GraphQL execution info
407
**kwargs: Field arguments
408
409
Returns:
410
Field resolution result
411
"""
412
```
413
414
**Usage Example:**
415
416
```python
417
class CacheFieldExtension(strawberry.extensions.FieldExtension):
418
def __init__(self, ttl: int = 300):
419
self.ttl = ttl
420
421
def resolve(self, next_, source, info, **kwargs):
422
# Create cache key from field path and arguments
423
cache_key = f"{info.path}:{hash(str(kwargs))}"
424
425
# Try to get from cache
426
cached_result = get_from_cache(cache_key)
427
if cached_result is not None:
428
return cached_result
429
430
# Execute resolver
431
result = next_(source, info, **kwargs)
432
433
# Cache result
434
set_cache(cache_key, result, ttl=self.ttl)
435
436
return result
437
438
@strawberry.type
439
class Query:
440
@strawberry.field(extensions=[CacheFieldExtension(ttl=600)])
441
def expensive_computation(self, value: int) -> int:
442
# Expensive computation that benefits from caching
443
return perform_heavy_calculation(value)
444
```
445
446
### Built-in Field Extensions
447
448
#### Input Mutation Extension
449
450
Adds support for the input mutation pattern, in which a mutation receives its arguments as a single input object.
451
452
```python { .api }
453
class InputMutationExtension(FieldExtension):
454
"""Extension for input mutation pattern support."""
455
456
def apply(self, field: StrawberryField) -> StrawberryField: ...
457
```
458
459
**Usage Example:**
460
461
```python
462
from strawberry.field_extensions import InputMutationExtension
463
464
@strawberry.input
465
class CreateUserInput:
466
name: str
467
email: str
468
469
@strawberry.type
470
class CreateUserPayload:
471
user: User
472
success: bool
473
474
@strawberry.type
475
class Mutation:
476
@strawberry.mutation(extensions=[InputMutationExtension()])
477
def create_user(self, input: CreateUserInput) -> CreateUserPayload:
478
user = User(name=input.name, email=input.email)
479
save_user(user)
480
return CreateUserPayload(user=user, success=True)
481
```
482
483
## Custom Extension Examples
484
485
### Logging Extension
486
487
```python
488
import logging
489
490
class LoggingExtension(strawberry.extensions.SchemaExtension):
491
def __init__(self):
492
self.logger = logging.getLogger("graphql")
493
494
def on_request_start(self):
495
self.logger.info("GraphQL request started")
496
497
def on_request_end(self):
498
self.logger.info("GraphQL request completed")
499
500
def on_validation_start(self):
501
self.logger.debug("Query validation started")
502
503
def on_validation_end(self):
504
self.logger.debug("Query validation completed")
505
```
506
507
### Metrics Extension
508
509
```python
510
class MetricsExtension(strawberry.extensions.SchemaExtension):
511
def __init__(self, metrics_client):
512
self.metrics = metrics_client
513
self.request_start_time = None
514
515
def on_request_start(self):
516
self.request_start_time = time.time()
517
self.metrics.increment("graphql.requests.started")
518
519
def on_request_end(self):
520
duration = time.time() - self.request_start_time
521
self.metrics.histogram("graphql.request.duration", duration)
522
self.metrics.increment("graphql.requests.completed")
523
524
def on_validation_end(self):
525
self.metrics.increment("graphql.validation.completed")
526
```
527
528
### Authentication Extension
529
530
```python
531
class AuthenticationExtension(strawberry.extensions.SchemaExtension):
532
def on_request_start(self):
533
# Validate authentication token
534
token = self.execution_context.context_value.get("auth_token")
535
if token:
536
user = validate_token(token)
537
self.execution_context.context_value["user"] = user
538
else:
539
self.execution_context.context_value["user"] = None
540
```
541
542
### Complex Field Extension
543
544
```python
545
class AuditFieldExtension(strawberry.extensions.FieldExtension):
546
def resolve(self, next_, source, info, **kwargs):
547
# Log field access
548
user = info.context.user
549
field_name = info.field_name
550
551
audit_log.info(
552
f"User {user.id if user else 'anonymous'} accessed field {field_name}",
553
extra={
554
"user_id": user.id if user else None,
555
"field_name": field_name,
556
"query_path": info.path
557
}
558
)
559
560
# Execute resolver
561
result = next_(source, info, **kwargs)
562
563
return result
564
565
@strawberry.type
566
class User:
567
id: strawberry.ID
568
name: str
569
570
@strawberry.field(extensions=[AuditFieldExtension()])
571
def sensitive_data(self) -> str:
572
return self._sensitive_information
573
```
574
575
## Extension Configuration
576
577
### Multiple Extensions
578
579
```python
580
# Combine multiple extensions
581
schema = strawberry.Schema(
582
query=Query,
583
extensions=[
584
TimingExtension(),
585
LoggingExtension(),
586
QueryDepthLimiter(max_depth=15),
587
ValidationCache(maxsize=500),
588
ParserCache(maxsize=100),
589
DisableIntrospection(), # For production
590
MaskErrors() # For production
591
]
592
)
593
```
594
595
### Conditional Extensions
596
597
```python
598
import os
599
600
def create_schema():
601
extensions = [
602
QueryDepthLimiter(max_depth=20),
603
ValidationCache(),
604
ParserCache()
605
]
606
607
# Add production-only extensions
608
if os.getenv("ENVIRONMENT") == "production":
609
extensions.extend([
610
DisableIntrospection(),
611
MaskErrors(),
612
MaxTokensLimiter(max_token_count=5000)
613
])
614
615
# Add development-only extensions
616
if os.getenv("ENVIRONMENT") == "development":
617
extensions.extend([
618
TimingExtension(),
619
LoggingExtension()
620
])
621
622
return strawberry.Schema(
623
query=Query,
624
extensions=extensions
625
)
626
```