A lightweight library for converting complex datatypes to and from native Python datatypes
—
Marshmallow provides decorators for custom data transformation and validation logic that execute at specific points during serialization and deserialization. These hooks enable pre-processing, post-processing, and custom validation workflows.
Decorators that register methods to execute before and after serialization/deserialization operations.
def pre_dump(pass_collection=False):
    """Register a schema method to run before objects are serialized.

    Decorate a schema method with this to have it invoked ahead of
    ``dump``/``dumps``.

    Args:
        pass_collection: When True and the schema is used with ``many=True``,
            the method receives the whole collection at once instead of being
            called once per item.
    """
def post_dump(pass_collection=False, pass_original=False):
    """Register a schema method to run after objects are serialized.

    Decorate a schema method with this to have it invoked once
    ``dump``/``dumps`` has produced the serialized output.

    Args:
        pass_collection: When True and the schema is used with ``many=True``,
            the method receives the whole collection at once instead of being
            called once per item.
        pass_original: When True, the original object that was serialized is
            handed to the method as an extra argument.
    """
def pre_load(pass_collection=False):
    """Register a schema method to run before input data is deserialized.

    Decorate a schema method with this to have it invoked ahead of
    ``load``/``loads``.

    Args:
        pass_collection: When True and the schema is used with ``many=True``,
            the method receives the whole collection at once instead of being
            called once per item.
    """
def post_load(pass_collection=False, pass_original=False):
    """Register a schema method to run after input data is deserialized.

    Decorate a schema method with this to have it invoked once
    ``load``/``loads`` has produced the deserialized result.

    Args:
        pass_collection: When True and the schema is used with ``many=True``,
            the method receives the whole collection at once instead of being
            called once per item.
        pass_original: When True, the original input data is handed to the
            method as an extra argument.
    """


# Decorators for registering custom validation methods.
def validates(*field_names):
    """Register a schema method as the validator for one or more fields.

    Decorate schema methods that check the value of specific fields.

    Args:
        *field_names: Names of the fields (strings) that the decorated method
            validates, given as positional arguments.
    """
def validates_schema(pass_collection=False, pass_original=False, skip_on_field_errors=True):
    """Register a schema method as a validator for the schema as a whole.

    Use it for checks that span several fields (cross-field validation).

    Args:
        pass_collection: When True and ``many=True``, the method receives the
            whole collection at once.
        pass_original: When True, the original input data is handed to the
            method as an extra argument.
        skip_on_field_errors: When True, the method is skipped if any field
            has already failed validation.
    """


from marshmallow import Schema, fields, pre_dump, pre_load
class UserSchema(Schema):
    """Dump users with a derived ``full_name``; normalize raw input on load."""

    username = fields.Str()
    email = fields.Email()
    full_name = fields.Str()

    @pre_dump
    def process_user(self, user, **kwargs):
        """Process user object before serialization.

        Sets ``user.full_name`` to "<first_name> <last_name>" when the object
        has both ``first_name`` and ``last_name`` attributes, then returns the
        same object.
        """
        # NOTE(review): this writes onto the object being dumped; a
        # fields.Method/fields.Function for full_name would avoid the mutation.
        if hasattr(user, 'first_name') and hasattr(user, 'last_name'):
            user.full_name = f"{user.first_name} {user.last_name}"
        return user

    @pre_load
    def clean_input(self, data, **kwargs):
        """Clean input data before deserialization.

        Returns a NEW dict in which every string value has surrounding
        whitespace stripped and a string ``email`` is lower-cased. The caller's
        input dict is not modified. Non-dict input is returned unchanged
        instead of raising AttributeError here.
        """
        if not isinstance(data, dict):
            return data
        cleaned = {
            key: value.strip() if isinstance(value, str) else value
            for key, value in data.items()
        }
        # Only normalize real strings; None or other types are left for the
        # Email field to reject.
        email = cleaned.get('email')
        if isinstance(email, str):
            cleaned['email'] = email.lower()
        return cleaned


from marshmallow import Schema, fields, post_dump, post_load
class ArticleSchema(Schema):
    """Dump articles with derived reading stats; load them into ``Article`` objects."""

    title = fields.Str()
    content = fields.Str()
    author_id = fields.Int()
    created_at = fields.DateTime()

    # pass_original=True is required: the hook declares an ``original_data``
    # positional parameter, and without pass_original the hook is invoked
    # without that argument (TypeError).
    @post_dump(pass_original=True)
    def add_metadata(self, data, original_data, **kwargs):
        """Add metadata after serialization.

        Adds ``word_count`` (whitespace-separated words in ``content``) and
        ``reading_time`` (minutes at ~200 words/min, rounded down, minimum 1)
        to the serialized dict and returns it.
        """
        # ``content`` may be present but None, which has no .split().
        content = data.get('content') or ''
        data['word_count'] = len(content.split())
        data['reading_time'] = max(1, data['word_count'] // 200)  # ~200 words/min
        return data

    @post_load
    def create_object(self, data, **kwargs):
        """Convert the deserialized dict into an ``Article`` instance."""
        return Article(**data)


class Article:
    """Plain object whose attributes are taken from the keyword arguments."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


from marshmallow import Schema, fields, validates, ValidationError
import re


class ProductSchema(Schema):
    """Validate product name, price and discount percentage on load."""

    name = fields.Str()
    price = fields.Decimal()
    category = fields.Str()
    discount_percent = fields.Float()

    # The ``**kwargs`` on each validator accepts extra keyword arguments;
    # marshmallow 4 passes ``data_key`` to @validates methods (confirm against
    # the installed version).

    @validates('name')
    def validate_name(self, value, **kwargs):
        """Validate product name.

        Raises:
            ValidationError: if the stripped name is shorter than 2 characters,
                or if it contains one of the forbidden words as a whole word.
        """
        if len(value.strip()) < 2:
            raise ValidationError('Product name must be at least 2 characters.')
        # Compare whole words (runs of letters), not substrings, so names like
        # "Shrimp Scampi" are not rejected for containing "scam".
        forbidden = {'spam', 'fake', 'scam'}
        words = set(re.findall(r'[^\W\d_]+', value.lower()))
        if words & forbidden:
            raise ValidationError('Product name contains forbidden words.')

    @validates('price')
    def validate_price(self, value, **kwargs):
        """Validate product price.

        Raises:
            ValidationError: if the price is not positive or exceeds 10000.
        """
        if value <= 0:
            raise ValidationError('Price must be positive.')
        if value > 10000:
            raise ValidationError('Price cannot exceed $10,000.')

    @validates('discount_percent')
    def validate_discount(self, value, **kwargs):
        """Validate discount percentage.

        Raises:
            ValidationError: if the value is outside the inclusive range 0-100.
        """
        if not (0 <= value <= 100):
            raise ValidationError('Discount must be between 0 and 100 percent.')


from marshmallow import Schema, fields, validates_schema, ValidationError
from decimal import Decimal, InvalidOperation


class OrderSchema(Schema):
    """Validate order totals and coupon/discount consistency on load."""

    items = fields.List(fields.Dict())
    shipping_cost = fields.Decimal()
    total_cost = fields.Decimal()
    coupon_code = fields.Str(allow_none=True)
    discount_amount = fields.Decimal(load_default=0)

    @staticmethod
    def _as_decimal(value):
        """Convert an item's price/quantity to a finite Decimal or raise ValidationError."""
        try:
            # Go through str() so a float such as 19.99 becomes Decimal('19.99')
            # rather than its binary approximation; this also keeps the arithmetic
            # below in Decimal (float + Decimal raises TypeError).
            result = Decimal(str(value))
        except InvalidOperation as exc:
            raise ValidationError('Item price and quantity must be numeric.') from exc
        if not result.is_finite():
            raise ValidationError('Item price and quantity must be numeric.')
        return result

    @validates_schema
    def validate_order_totals(self, data, **kwargs):
        """Validate order calculations.

        Checks that ``total_cost`` equals the sum of ``price * quantity`` over
        ``items``, plus ``shipping_cost``, minus ``discount_amount``, to within
        0.01. The comparison is skipped when ``total_cost`` was not supplied.

        Raises:
            ValidationError: if there are no items, if an item's price or
                quantity is not numeric, or if the totals do not match.
        """
        items = data.get('items', [])
        if not items:
            raise ValidationError('Order must contain at least one item.')
        # Calculate expected total
        item_total = sum(
            (
                self._as_decimal(item.get('price', 0))
                * self._as_decimal(item.get('quantity', 0))
                for item in items
            ),
            Decimal(0),
        )
        shipping = data.get('shipping_cost', 0)
        discount = data.get('discount_amount', 0)
        expected_total = item_total + shipping - discount
        actual_total = data.get('total_cost')
        if actual_total is None:
            # total_cost is optional; there is nothing to compare against.
            return
        if abs(expected_total - actual_total) > Decimal('0.01'):  # Allow for rounding
            raise ValidationError('Total cost calculation is incorrect.')

    @validates_schema
    def validate_coupon(self, data, **kwargs):
        """Validate coupon usage.

        Raises:
            ValidationError: if a coupon is given with a zero discount, or a
                positive discount is given without a coupon.
        """
        coupon = data.get('coupon_code')
        discount = data.get('discount_amount', 0)
        if coupon and discount == 0:
            raise ValidationError('Coupon code provided but no discount applied.')
        if not coupon and discount > 0:
            raise ValidationError('Discount applied without coupon code.')


from marshmallow import Schema, fields, pre_dump, post_load
import json

# Used by the examples below (raised, and applied as a decorator).
from marshmallow import ValidationError, validates_schema


class BatchProcessingSchema(Schema):
    """Dump batches with a derived ``item_count``; cap total items on load."""

    items = fields.List(fields.Dict())
    batch_id = fields.Str()
    processed_at = fields.DateTime()
    # Declared so the count computed in add_batch_metadata is actually
    # emitted: only declared fields are serialized.
    item_count = fields.Int(dump_only=True)

    @pre_dump(pass_collection=True)
    def add_batch_metadata(self, data, **kwargs):
        """Attach ``item_count`` to each batch before serialization.

        Accepts a single batch dict or a list of them (``pass_collection``).
        Returns new dicts, so the caller's batches are not modified.
        """
        def with_count(batch):
            # ``items`` may be missing or None on the object being dumped.
            return {**batch, 'item_count': len(batch.get('items') or [])}

        if isinstance(data, list):
            # Processing a collection of batches
            return [with_count(batch) for batch in data]
        # Processing a single batch
        return with_count(data)

    @post_load(pass_collection=True)
    def validate_batch_limits(self, data, **kwargs):
        """Validate batch size limits after loading.

        Applies to a single batch as well as to a collection, where the limit
        is on the total across all batches.

        Raises:
            ValidationError: if more than 1000 items are loaded in total.
        """
        batches = data if isinstance(data, list) else [data]
        total_items = sum(len(batch.get('items', [])) for batch in batches)
        if total_items > 1000:
            raise ValidationError('Batch processing limited to 1000 items total.')
        return data


class AuditableSchema(Schema):
    """Schema with automatic audit trail."""

    created_by = fields.Str()
    created_at = fields.DateTime()
    updated_by = fields.Str()
    updated_at = fields.DateTime()

    # NOTE(review): ``self.context`` and the ``context=`` constructor argument
    # were removed in marshmallow 4 (the version that uses ``pass_collection``);
    # there, the ``marshmallow.experimental.context.Context`` API replaces
    # them. Confirm against the installed version.

    @pre_load
    def set_audit_fields(self, data, **kwargs):
        """Fill audit fields from ``context['current_user']`` before loading.

        When a current user is present and the input is a dict, returns a copy
        in which ``created_by`` is set only if absent and ``updated_by`` is
        always overwritten. The caller's input is not modified.
        """
        context = self.context or {}
        current_user = context.get('current_user')
        if current_user and isinstance(data, dict):
            data = dict(data)
            data.setdefault('created_by', current_user)
            data['updated_by'] = current_user
        return data

    @post_dump
    def remove_sensitive_audit_fields(self, data, **kwargs):
        """Drop ``created_by``/``updated_by`` from output when ``context['external_api']`` is truthy."""
        context = self.context or {}
        if context.get('external_api'):
            data.pop('created_by', None)
            data.pop('updated_by', None)
        return data


# Usage with context
input_data = {'updated_at': '2024-01-01T12:00:00'}  # example payload to deserialize
schema = AuditableSchema(context={'current_user': 'john_doe', 'external_api': False})
result = schema.load(input_data)


class RobustSchema(Schema):
    """Accept the ``data`` field either as a dict or as a JSON-encoded string."""

    name = fields.Str()
    data = fields.Dict()

    @pre_load
    def safe_preprocessing(self, data, **kwargs):
        """Decode a JSON string in ``data['data']`` before field deserialization.

        Non-dict input, or a ``data`` value that is not a string, is returned
        unchanged. Otherwise a new dict with the decoded value is returned.

        Raises:
            ValidationError: if ``data['data']`` is a string that is not valid JSON.
        """
        if not isinstance(data, dict) or not isinstance(data.get('data'), str):
            return data
        try:
            decoded = json.loads(data['data'])
        except json.JSONDecodeError as e:
            # Convert to validation error, keeping the parse error as the cause
            raise ValidationError(f'Invalid JSON in data field: {e}') from e
        return {**data, 'data': decoded}

    @validates_schema(skip_on_field_errors=False)
    def always_validate(self, data, **kwargs):
        """Run validation even if field errors exist."""
        # skip_on_field_errors=False: runs regardless of field validation failures
        if not data:
            raise ValidationError('Schema data cannot be empty.')


# Install with Tessl CLI
npx tessl i tessl/pypi-marshmallow