Efficient tree implementations for Django models providing three different tree algorithms with unified API
—
Helper functions for number base conversion, tree validation, database operations, and other utility functions that support tree operations. These utilities are used internally by the tree implementations but are also available for advanced use cases.
Utility functions and classes for converting between different number bases, primarily used by MP_Node for path encoding.
class NumConv:
"""
Number conversion utility for different bases.
Converts integers to string representations using custom alphabets
and radixes, used for encoding tree paths in MP_Node.
"""
def __init__(self, radix=10, alphabet=BASE85):
"""
Initialize number converter.
Parameters:
radix (int): Base for number conversion (default: 10)
alphabet (str): Character set for encoding (default: BASE85)
"""
self.radix = radix
self.alphabet = alphabet
def int2str(self, num):
"""
Convert integer to string representation.
Parameters:
num (int): Integer to convert
Returns:
str: String representation in configured base
"""
def str2int(self, num):
"""
Convert string representation back to integer.
Parameters:
num (str): String representation to convert
Returns:
int: Integer value
"""

def int2str(num, radix=10, alphabet=BASE85):
"""
Convert integer to string representation.
Quick conversion function without creating NumConv instance.
Parameters:
num (int): Integer to convert
radix (int): Base for conversion (default: 10)
alphabet (str): Character set for encoding (default: BASE85)
Returns:
str: String representation in specified base
"""
def str2int(num, radix=10, alphabet=BASE85):
"""
Convert string representation back to integer.
Quick conversion function without creating NumConv instance.
Parameters:
num (str): String representation to convert
radix (int): Base for conversion (default: 10)
alphabet (str): Character set for encoding (default: BASE85)
Returns:
int: Integer value
"""

# Base16 alphabet (hexadecimal)
BASE16 = '0123456789ABCDEF'
# RFC4648 Base32 alphabet
BASE32 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'
# Base32 hex alphabet
BASE32HEX = '0123456789ABCDEFGHIJKLMNOPQRSTUV'
# Base62 alphabet (alphanumeric)
BASE62 = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# RFC4648 Base64 alphabet
BASE64 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
# URL-safe Base64 alphabet
BASE64URL = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_'
# Base85 alphabet (default alphabet for NumConv)
BASE85 = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&()*+-;<=>?@^_`{|}~'

Functions for checking and repairing tree structure integrity.
@classmethod
def find_problems(cls):
"""
Check tree structure for problems and inconsistencies.
Identifies various types of tree corruption including:
- Orphaned nodes
- Invalid parent references
- Incorrect depth calculations
- Path inconsistencies (MP_Node)
- Boundary violations (NS_Node)
Returns:
list: List of problem descriptions found in tree
"""
@classmethod
def fix_tree(cls, destructive=False):
"""
Fix tree structure inconsistencies.
Attempts to repair common tree problems by recalculating
tree structure fields and removing invalid nodes.
Parameters:
destructive (bool): Whether to remove unfixable nodes
Returns:
dict: Summary of fixes applied
Warning:
destructive=True may result in data loss
"""

Database-specific utility functions for SQL generation and database operations.
@classmethod
def get_database_vendor(cls, action='read'):
"""
Get database vendor for SQL generation.
Determines database backend for generating appropriate SQL.
Parameters:
action (str): Type of operation ('read', 'write', 'delete')
Returns:
str: Database vendor ('postgresql', 'mysql', 'sqlite', 'mssql')
"""

def sql_concat(*args, **kwargs):
"""
Generate database-specific SQL for string concatenation.
Parameters:
*args: String expressions to concatenate
**kwargs: Additional options (vendor, etc.)
Returns:
str: Database-specific concatenation SQL
"""
def sql_length(field, vendor=None):
"""
Generate database-specific SQL for string length.
Parameters:
field (str): Field name or expression
vendor (str, optional): Database vendor override
Returns:
str: Database-specific length function SQL
"""
def sql_substr(field, pos, length=None, **kwargs):
"""
Generate database-specific SQL for substring extraction.
Parameters:
field (str): Field name or expression
pos (int): Starting position (1-based)
length (int, optional): Length of substring
**kwargs: Additional options (vendor, etc.)
Returns:
str: Database-specific substring SQL
"""

def get_result_class(cls):
"""
Determine appropriate result class for tree operations.
Handles inheritance scenarios where operations may return
different model classes than the initiating class.
Parameters:
cls (type): Tree model class
Returns:
type: Appropriate result class for operations
"""

Specialized utilities for Nested Sets tree operations.
def merge_deleted_counters(c1, c2):
"""
Merge delete operation counters for nested sets.
Combines statistics from multiple delete operations
for batch processing and reporting.
Parameters:
c1 (dict): First counter dictionary
c2 (dict): Second counter dictionary
Returns:
dict: Merged counter statistics
"""

from treebeard.numconv import NumConv, int2str, str2int, BASE62
# Quick conversions
hex_value = int2str(255, radix=16, alphabet='0123456789ABCDEF')
print(hex_value) # 'FF'
decimal_value = str2int('FF', radix=16, alphabet='0123456789ABCDEF')
print(decimal_value) # 255
# Using NumConv class for repeated conversions
converter = NumConv(radix=62, alphabet=BASE62)
# Convert sequence of numbers
numbers = [1, 100, 1000, 10000]
encoded = [converter.int2str(n) for n in numbers]
print(encoded) # ['1', '1c', 'G8', '2bI']
# Convert back
decoded = [converter.str2int(s) for s in encoded]
print(decoded) # [1, 100, 1000, 10000]
# Custom alphabet for specific encoding needs
custom_alphabet = '0123456789ABCDEFGHIJ' # Base 20
converter = NumConv(radix=20, alphabet=custom_alphabet)
result = converter.int2str(399) # 19 * 20 + 19 = 399
print(result) # 'JJ'

from myapp.models import Category
# Check for tree problems
problems = Category.find_problems()
if problems:
print("Tree problems found:")
for problem in problems:
print(f" - {problem}")
else:
print("Tree structure is valid")
# Fix tree problems
print("Attempting to fix tree...")
fixes = Category.fix_tree(destructive=False)
print(f"Applied {len(fixes)} fixes")
# For serious corruption, use destructive repair
if problems:
print("Applying destructive fixes...")
fixes = Category.fix_tree(destructive=True)
print("Tree repaired, but some data may have been lost")
# Verify fix was successful
remaining_problems = Category.find_problems()
if not remaining_problems:
print("Tree structure successfully repaired")

from treebeard.numconv import BASE62
from treebeard.mp_tree import MP_Node
class CompactCategory(MP_Node):
    """Category with compact path encoding.

    Overrides the path alphabet and step length so each path step is a
    three-character Base62 string.
    """

    name = models.CharField(max_length=100)

    # Use Base62 for more compact paths
    alphabet = BASE62
    steplen = 3  # Shorter steps for more levels

    @classmethod
    def get_path_info(cls, path=None):
        """
        Get path information for debugging.

        Splits ``path`` into chunks of ``cls.steplen`` characters and decodes
        each chunk using ``cls.alphabet``.

        Parameters:
            path (str, optional): Materialized path to inspect

        Returns:
            dict: ``{'path', 'depth', 'steps'}`` where ``steps`` is a list of
            ``{'step', 'value'}`` dicts, or None when ``path`` is empty/None
        """
        if not path:
            return None

        # str2int is the module-level converter from treebeard.numconv; it is
        # imported here because the surrounding example only imports BASE62
        # and the node class is not shown to expose a public ``str2int``.
        from treebeard.numconv import str2int

        radix = len(cls.alphabet)
        chunks = (path[i:i + cls.steplen] for i in range(0, len(path), cls.steplen))
        steps = [
            {'step': chunk, 'value': str2int(chunk, radix=radix, alphabet=cls.alphabet)}
            for chunk in chunks
        ]
        return {
            'path': path,
            'depth': len(steps),
            'steps': steps,
        }
# Usage
root = CompactCategory.add_root(name='Root')
child = root.add_child(name='Child')
print("Path info:", CompactCategory.get_path_info(child.path))
# Output: {'path': '001001', 'depth': 2, 'steps': [{'step': '001', 'value': 1}, {'step': '001', 'value': 1}]}

# Check database vendor
vendor = Category.get_database_vendor()
print(f"Using database: {vendor}")
# Generate database-specific SQL (for advanced use cases)
if vendor == 'postgresql':
# PostgreSQL-specific optimizations
sql = "SELECT * FROM category WHERE path ~ '^001'"
elif vendor == 'mysql':
# MySQL-specific optimizations
sql = "SELECT * FROM category WHERE path LIKE '001%'"
else:
# Generic SQL
sql = "SELECT * FROM category WHERE path LIKE '001%'"
# Use in raw queries
from django.db import connection
with connection.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()


def analyze_tree_structure(model_class):
    """
    Analyze tree structure and provide statistics.

    Parameters:
        model_class (type): Tree model class; must provide ``objects``,
            ``get_root_nodes()`` and nodes with ``get_children_count()``

    Returns:
        dict: Statistics with the keys ``total_nodes``, ``root_nodes``,
        ``max_depth``, ``avg_children``, ``leaf_nodes``, ``internal_nodes``,
        ``max_children`` and ``depth_distribution``. Every key is always
        present, so an empty tree yields zeros and an empty distribution
        instead of missing keys.
    """
    from collections import Counter

    stats = {
        'total_nodes': model_class.objects.count(),
        'root_nodes': model_class.get_root_nodes().count(),
        'max_depth': 0,
        'avg_children': 0,
        'leaf_nodes': 0,
        'internal_nodes': 0,
        'max_children': 0,
        'depth_distribution': {},
    }

    # Calculate depth statistics (materialize once, then reuse the list)
    depths = list(model_class.objects.values_list('depth', flat=True))
    if depths:
        stats['max_depth'] = max(depths)
        stats['depth_distribution'] = dict(Counter(depths))

    # Calculate children statistics.
    # NOTE: get_children_count() is called once per node, so this is not
    # suitable for very large trees.
    children_counts = [
        node.get_children_count() for node in model_class.objects.all()
    ]
    if children_counts:
        leaf_count = children_counts.count(0)
        stats['avg_children'] = sum(children_counts) / len(children_counts)
        stats['max_children'] = max(children_counts)
        stats['leaf_nodes'] = leaf_count
        stats['internal_nodes'] = stats['total_nodes'] - leaf_count

    return stats
# Usage
stats = analyze_tree_structure(Category)
print("Tree Analysis:")
print(f" Total nodes: {stats['total_nodes']}")
print(f" Root nodes: {stats['root_nodes']}")
print(f" Max depth: {stats['max_depth']}")
print(f" Leaf nodes: {stats['leaf_nodes']}")
print(f" Average children per node: {stats['avg_children']:.2f}")

import time
from django.db import connection
def benchmark_tree_operations(model_class, num_operations=100):
    """
    Benchmark common tree operations.

    Creates a temporary root node, adds ``num_operations`` children to it,
    iterates over all descendants, moves the 2nd..10th child to the right of
    the first one, and finally deletes the temporary root.

    Parameters:
        model_class (type): Tree model class providing ``add_root()``
        num_operations (int): Number of children to create (default: 100)

    Returns:
        dict: ``total_time``, ``child_creation_time``, ``traversal_time``,
        ``move_time`` (seconds), ``total_queries`` and the per-operation
        averages ``avg_time_per_operation`` / ``queries_per_operation``
        (both 0 when ``num_operations`` is 0)
    """
    # NOTE: Django only records connection.queries when DEBUG is enabled;
    # otherwise the query counts reported here stay at 0.
    initial_queries = len(connection.queries)
    # perf_counter() is monotonic, so intervals are not affected by
    # system clock adjustments the way time.time() is.
    start_time = time.perf_counter()

    # Create test tree
    root = model_class.add_root(name='Benchmark Root')
    try:
        # Benchmark child creation
        child_start = time.perf_counter()
        for i in range(num_operations):
            root.add_child(name=f'Child {i}')
        child_time = time.perf_counter() - child_start

        # Benchmark tree traversal
        traverse_start = time.perf_counter()
        list(root.get_descendants())  # Force evaluation
        traverse_time = time.perf_counter() - traverse_start

        # Benchmark move operations
        children = list(root.get_children()[:10])  # Get first 10 children
        move_start = time.perf_counter()
        for child in children[1:]:
            child.move(children[0], 'right')
        move_time = time.perf_counter() - move_start

        total_time = time.perf_counter() - start_time
        total_queries = len(connection.queries) - initial_queries
    finally:
        # Cleanup even if one of the benchmarked operations raised
        root.delete()

    # Guard the averages so num_operations=0 does not raise ZeroDivisionError
    if num_operations:
        avg_time = total_time / num_operations
        avg_queries = total_queries / num_operations
    else:
        avg_time = 0
        avg_queries = 0

    return {
        'total_time': total_time,
        'child_creation_time': child_time,
        'traversal_time': traverse_time,
        'move_time': move_time,
        'total_queries': total_queries,
        'avg_time_per_operation': avg_time,
        'queries_per_operation': avg_queries,
    }
# Compare implementations
al_stats = benchmark_tree_operations(ALCategory, 100)
mp_stats = benchmark_tree_operations(MPCategory, 100)
ns_stats = benchmark_tree_operations(NSCategory, 100)
print("Performance Comparison:")
print(f"AL_Node: {al_stats['total_time']:.3f}s, {al_stats['total_queries']} queries")
print(f"MP_Node: {mp_stats['total_time']:.3f}s, {mp_stats['total_queries']} queries")
print(f"NS_Node: {ns_stats['total_time']:.3f}s, {ns_stats['total_queries']} queries")


def validate_tree_business_rules(model_class):
    """
    Validate business-specific tree rules.

    Checks three rules: maximum depth, root naming pattern and (for models
    with a ``parent`` attribute) dangling parent references.

    Parameters:
        model_class (type): Tree model class to validate

    Returns:
        list: Human-readable error messages; empty when all rules pass
    """
    errors = []

    # Rule: No more than 5 levels deep.
    # NOTE(review): assumes 1-based depths (root nodes have depth 1), so
    # "deeper than 5 levels" means depth > 5 - confirm for your model.
    deep_count = model_class.objects.filter(depth__gt=5).count()
    if deep_count:
        errors.append(f"Found {deep_count} nodes deeper than 5 levels")

    # Rule: Root nodes must have specific naming pattern
    invalid_roots = model_class.get_root_nodes().exclude(name__regex=r'^[A-Z]')
    if invalid_roots.exists():
        errors.append("Root nodes must start with capital letter")

    # Rule: No orphaned nodes (for AL_Node)
    if hasattr(model_class, 'parent'):
        # An orphan references a parent id that no longer exists. Filtering
        # with ``parent__in=objects.none()`` could never match any row, so
        # exclude the rows whose parent id IS present instead.
        orphaned_count = (
            model_class.objects
            .filter(parent__isnull=False)
            .exclude(parent_id__in=model_class.objects.values('pk'))
            .count()
        )
        if orphaned_count:
            errors.append(f"Found {orphaned_count} orphaned nodes")

    return errors
# Usage in management command
class Command(BaseCommand):
    """Management command that reports tree business-rule violations."""

    def handle(self, *args, **options):
        """Run the business-rule validation for Category and print the outcome."""
        problems = validate_tree_business_rules(Category)

        # Happy path first: nothing to report.
        if not problems:
            self.stdout.write(self.style.SUCCESS("All business rules validated"))
            return

        for message in problems:
            self.stdout.write(self.style.ERROR(message))


import json
from django.core.serializers.json import DjangoJSONEncoder
def export_tree_to_json(model_class, root_node=None, include_fields=None):
    """
    Export tree structure to JSON format.

    Parameters:
        model_class (type): Tree model class; used for ``get_root_nodes()``
            when no ``root_node`` is given
        root_node (optional): Node to export; when falsy, all roots are exported
        include_fields (list, optional): Attribute names to copy from each
            node; when falsy, every model field except the tree bookkeeping
            fields is exported

    Returns:
        dict | list: A nested dict for a single ``root_node``, otherwise a
        list with one nested dict per root node. Children appear under the
        ``'children'`` key only when a node has any.
    """
    # Bookkeeping columns that are skipped when exporting "all" fields.
    structural = ('id', 'path', 'depth', 'numchild', 'lft', 'rgt', 'tree_id')

    def serialize_node(node):
        payload = {'id': node.pk}

        if include_fields:
            # Include specified fields
            payload.update((attr, getattr(node, attr)) for attr in include_fields)
        else:
            # Include all non-tree fields
            payload.update(
                (model_field.name, model_field.value_from_object(node))
                for model_field in node._meta.fields
                if model_field.name not in structural
            )

        # Add children recursively
        kids = node.get_children()
        if kids:
            payload['children'] = [serialize_node(kid) for kid in kids]
        return payload

    if not root_node:
        return [serialize_node(top) for top in model_class.get_root_nodes()]
    return serialize_node(root_node)
def import_tree_from_json(model_class, json_data, parent=None):
    """
    Import tree structure from JSON format.

    The input is not modified: each node dict is copied before its ``'id'``
    and ``'children'`` entries are removed, so the same data can be imported
    again or inspected afterwards.

    Parameters:
        model_class (type): Tree model class providing ``add_root()``
        json_data (str | dict | list): JSON text, a single node dict, or a
            list of node dicts (children nested under ``'children'``)
        parent (optional): Node to attach the imported nodes to; when None,
            the nodes are created as roots

    Returns:
        list: All created nodes, each parent followed by its descendants
    """
    if isinstance(json_data, str):
        json_data = json.loads(json_data)
    if not isinstance(json_data, list):
        json_data = [json_data]

    created_nodes = []
    for node_data in json_data:
        # Shallow copy so popping keys does not strip the caller's data
        fields = dict(node_data)
        children_data = fields.pop('children', [])
        fields.pop('id', None)  # Remove original ID

        # Create node; compare against None explicitly rather than relying
        # on the truthiness of a model instance.
        if parent is not None:
            node = parent.add_child(**fields)
        else:
            node = model_class.add_root(**fields)
        created_nodes.append(node)

        # Create children recursively
        if children_data:
            created_nodes.extend(
                import_tree_from_json(model_class, children_data, parent=node)
            )

    return created_nodes
# Usage
tree_data = export_tree_to_json(Category, include_fields=['name', 'description', 'active'])
with open('tree_backup.json', 'w') as f:
json.dump(tree_data, f, cls=DjangoJSONEncoder, indent=2)
# Later, restore from backup
with open('tree_backup.json', 'r') as f:
backup_data = json.load(f)
restored_nodes = import_tree_from_json(Category, backup_data)
print(f"Restored {len(restored_nodes)} nodes")

These utilities provide the foundation for advanced tree operations, custom validation, performance monitoring, and data migration scenarios.
Install with Tessl CLI
npx tessl i tessl/pypi-django-treebeard