# Decorators and Hooks

Marshmallow provides decorators that run custom data transformation and validation logic at specific points during serialization and deserialization. These hooks enable pre-processing, post-processing, and custom validation workflows.

## Capabilities

### Processing Decorators

Decorators that register methods to execute before and after serialization/deserialization operations.

```python { .api }
def pre_dump(pass_collection=False):
    """
    Register a method to invoke before serializing objects.

    Parameters:
    - pass_collection: bool, if True, the entire collection is passed to the method
      instead of individual items when many=True

    Usage: Decorate schema methods that should run before dump/dumps.
    May be applied bare (``@pre_dump``) or with arguments
    (``@pre_dump(pass_collection=True)``).
    """

def post_dump(pass_collection=False, pass_original=False):
    """
    Register a method to invoke after serializing objects.

    Parameters:
    - pass_collection: bool, if True, the entire collection is passed to the method
      instead of individual items when many=True
    - pass_original: bool, if True, the original object is passed as an additional
      argument to the decorated method

    Usage: Decorate schema methods that should run after dump/dumps.
    May be applied bare (``@post_dump``) or with arguments
    (``@post_dump(pass_original=True)``).
    """

def pre_load(pass_collection=False):
    """
    Register a method to invoke before deserializing data.

    Parameters:
    - pass_collection: bool, if True, the entire collection is passed to the method
      instead of individual items when many=True

    Usage: Decorate schema methods that should run before load/loads.
    May be applied bare (``@pre_load``) or with arguments
    (``@pre_load(pass_collection=True)``).
    """

def post_load(pass_collection=False, pass_original=False):
    """
    Register a method to invoke after deserializing data.

    Parameters:
    - pass_collection: bool, if True, the entire collection is passed to the method
      instead of individual items when many=True
    - pass_original: bool, if True, the original input data is passed as an additional
      argument to the decorated method

    Usage: Decorate schema methods that should run after load/loads.
    May be applied bare (``@post_load``) or with arguments
    (``@post_load(pass_collection=True)``).
    """
```

### Validation Decorators

Decorators for registering custom validation methods.

```python { .api }
def validates(*field_names):
    """
    Register a validator method for specific field(s).

    Parameters:
    - field_names: str, names of fields to validate (positional arguments)

    Usage: Decorate schema methods that validate specific fields,
    e.g. ``@validates('name')``.
    """

def validates_schema(pass_collection=False, pass_original=False, skip_on_field_errors=True):
    """
    Register a schema-level validator method.

    Parameters:
    - pass_collection: bool, if True, the entire collection is passed when many=True
    - pass_original: bool, if True, the original input data is passed as additional argument
    - skip_on_field_errors: bool, if True, skip schema validation when field errors exist

    Usage: Decorate schema methods that perform cross-field validation.
    May be applied bare (``@validates_schema``) or with arguments
    (``@validates_schema(skip_on_field_errors=False)``).
    """
```

## Usage Examples

### Pre-processing Hooks

```python
from marshmallow import Schema, fields, pre_dump, pre_load


class UserSchema(Schema):
    """Derive ``full_name`` before dumping and normalize raw input before loading."""

    username = fields.Str()
    email = fields.Email()
    full_name = fields.Str()

    @pre_dump
    def process_user(self, user, **kwargs):
        """Process user object before serialization."""
        # Add computed fields
        if hasattr(user, 'first_name') and hasattr(user, 'last_name'):
            user.full_name = f"{user.first_name} {user.last_name}"
        return user

    @pre_load
    def clean_input(self, data, **kwargs):
        """Clean input data before deserialization.

        Returns a new dict so the caller's input is left untouched.
        """
        # Strip whitespace from string fields
        cleaned = {
            key: value.strip() if isinstance(value, str) else value
            for key, value in data.items()
        }

        # Normalize email to lowercase; non-string values are left as-is so the
        # Email field reports them as a validation error instead of this hook crashing
        email = cleaned.get('email')
        if isinstance(email, str):
            cleaned['email'] = email.lower()

        return cleaned
```

### Post-processing Hooks

```python
from marshmallow import Schema, fields, post_dump, post_load


class ArticleSchema(Schema):
    """Serialize articles with derived metadata and load them into ``Article`` objects."""

    title = fields.Str()
    content = fields.Str()
    author_id = fields.Int()
    created_at = fields.DateTime()

    # pass_original=True is required for the hook to receive ``original_data``;
    # without it marshmallow calls the method with one argument fewer.
    @post_dump(pass_original=True)
    def add_metadata(self, data, original_data, **kwargs):
        """Add metadata after serialization."""
        # Add computed fields to serialized data; a missing or None content
        # counts as zero words
        data['word_count'] = len((data.get('content') or '').split())
        data['reading_time'] = max(1, data['word_count'] // 200)  # ~200 words/min
        return data

    @post_load
    def create_object(self, data, **kwargs):
        """Convert deserialized data to object."""
        # Return a custom object instead of dictionary
        return Article(**data)


class Article:
    """Plain attribute container built from keyword arguments."""

    def __init__(self, **kwargs):
        # Expose every supplied keyword as an instance attribute
        for key, value in kwargs.items():
            setattr(self, key, value)
```

### Field Validation Hooks

```python
from marshmallow import Schema, fields, validates, ValidationError


class ProductSchema(Schema):
    """Product schema with per-field validation hooks.

    Each validator accepts ``**kwargs`` so it keeps working when marshmallow
    passes extra keyword arguments (marshmallow 4 passes ``data_key``).
    """

    name = fields.Str()
    price = fields.Decimal()
    category = fields.Str()
    discount_percent = fields.Float()

    @validates('name')
    def validate_name(self, value, **kwargs):
        """Validate product name."""
        if len(value.strip()) < 2:
            raise ValidationError('Product name must be at least 2 characters.')

        # Check for forbidden words (case-insensitive substring match)
        forbidden = ('spam', 'fake', 'scam')
        lowered = value.lower()
        if any(word in lowered for word in forbidden):
            raise ValidationError('Product name contains forbidden words.')

    @validates('price')
    def validate_price(self, value, **kwargs):
        """Validate product price."""
        if value <= 0:
            raise ValidationError('Price must be positive.')

        if value > 10000:
            raise ValidationError('Price cannot exceed $10,000.')

    @validates('discount_percent')
    def validate_discount(self, value, **kwargs):
        """Validate discount percentage."""
        if not (0 <= value <= 100):
            raise ValidationError('Discount must be between 0 and 100 percent.')
```

### Schema-Level Validation

```python
from decimal import Decimal, InvalidOperation

from marshmallow import Schema, fields, validates_schema, ValidationError


class OrderSchema(Schema):
    """Order schema with cross-field checks on totals and coupon usage."""

    items = fields.List(fields.Dict())
    shipping_cost = fields.Decimal()
    total_cost = fields.Decimal()
    coupon_code = fields.Str(allow_none=True)
    discount_amount = fields.Decimal(load_default=0)

    @validates_schema
    def validate_order_totals(self, data, **kwargs):
        """Validate order calculations."""
        items = data.get('items') or []
        if not items:
            raise ValidationError('Order must contain at least one item.')

        # total_cost is optional at field level, so guard before doing arithmetic
        actual_total = data.get('total_cost')
        if actual_total is None:
            raise ValidationError('Total cost is required.')

        # Calculate expected total. Item dicts are untyped, so values may be
        # int/float/str; go through str() so floats can be combined with the
        # Decimal values produced by the Decimal fields.
        try:
            item_total = sum(
                Decimal(str(item.get('price', 0))) * Decimal(str(item.get('quantity', 0)))
                for item in items
            )
        except InvalidOperation:
            raise ValidationError('Item price and quantity must be numeric.') from None

        shipping = data.get('shipping_cost', 0)
        discount = data.get('discount_amount', 0)
        expected_total = item_total + shipping - discount

        if abs(expected_total - actual_total) > Decimal('0.01'):  # Allow for rounding
            raise ValidationError('Total cost calculation is incorrect.')

    @validates_schema
    def validate_coupon(self, data, **kwargs):
        """Validate coupon usage."""
        coupon = data.get('coupon_code')
        discount = data.get('discount_amount', 0)

        if coupon and discount == 0:
            raise ValidationError('Coupon code provided but no discount applied.')

        if not coupon and discount > 0:
            raise ValidationError('Discount applied without coupon code.')
```

### Collection Processing

```python
from marshmallow import Schema, fields, pre_dump, post_load, ValidationError


class BatchProcessingSchema(Schema):
    """Batch schema whose hooks receive the whole collection when ``many=True``."""

    items = fields.List(fields.Dict())
    batch_id = fields.Str()
    processed_at = fields.DateTime()
    # Declared so the value computed in add_batch_metadata appears in the output;
    # undeclared keys are not serialized.
    item_count = fields.Int(dump_only=True)

    @pre_dump(pass_collection=True)
    def add_batch_metadata(self, data, **kwargs):
        """Add metadata to entire batch before serialization."""
        # A pass_collection hook gets a list for many=True and a single batch otherwise
        batches = data if isinstance(data, list) else [data]
        for batch in batches:
            batch['item_count'] = len(batch.get('items', []))
        return data

    @post_load(pass_collection=True)
    def validate_batch_limits(self, data, **kwargs):
        """Validate batch size limits after loading."""
        # Apply the same limit whether one batch or a collection was loaded
        batches = data if isinstance(data, list) else [data]
        total_items = sum(len(batch.get('items', [])) for batch in batches)
        if total_items > 1000:
            raise ValidationError('Batch processing limited to 1000 items total.')
        return data
```

### Advanced Hook Patterns

```python
from marshmallow import Schema, fields, post_dump, pre_load
from marshmallow.experimental.context import Context


class AuditableSchema(Schema):
    """Schema with automatic audit trail.

    ``Schema.context`` was removed in marshmallow 4, so the hooks read the
    ambient ``Context`` instead; an empty dict is used when none is active.
    """

    created_by = fields.Str()
    created_at = fields.DateTime()
    updated_by = fields.Str()
    updated_at = fields.DateTime()

    @pre_load
    def set_audit_fields(self, data, **kwargs):
        """Set audit fields based on context."""
        context = Context.get(default={})
        current_user = context.get('current_user')

        if current_user:
            # Copy so the caller's input dict is not modified
            data = dict(data)
            data.setdefault('created_by', current_user)
            data['updated_by'] = current_user

        return data

    @post_dump
    def remove_sensitive_audit_fields(self, data, **kwargs):
        """Remove sensitive audit information for external APIs."""
        context = Context.get(default={})
        if context.get('external_api'):
            data.pop('created_by', None)
            data.pop('updated_by', None)

        return data


# Usage with context
input_data = {'created_at': '2024-01-01T12:00:00'}
schema = AuditableSchema()
with Context({'current_user': 'john_doe', 'external_api': False}):
    result = schema.load(input_data)
```

### Error Handling in Hooks

```python
import json
from collections.abc import Mapping

from marshmallow import Schema, ValidationError, fields, pre_load, validates_schema


class RobustSchema(Schema):
    """Schema whose hooks turn preprocessing failures into validation errors."""

    name = fields.Str()
    data = fields.Dict()

    @pre_load
    def safe_preprocessing(self, data, **kwargs):
        """Safely preprocess data with error handling.

        Decodes a JSON-encoded ``data`` value into a Python object; any other
        input is returned unchanged.
        """
        # Non-mapping input is passed through so marshmallow reports it itself
        if not isinstance(data, Mapping):
            return data

        payload = data.get('data')
        if isinstance(payload, str):
            try:
                decoded = json.loads(payload)
            except json.JSONDecodeError as e:
                # Convert to validation error
                raise ValidationError(f'Invalid JSON in data field: {e}') from e
            # Build a new dict rather than mutating the caller's input
            data = {**data, 'data': decoded}

        return data

    @validates_schema(skip_on_field_errors=False)
    def always_validate(self, data, **kwargs):
        """Run validation even if field errors exist."""
        # This runs regardless of field validation failures
        if not data:
            raise ValidationError('Schema data cannot be empty.')
```