Pydantic Advanced Features - Complete Guide
Advanced Pydantic features enable complex data validation, custom types, and optimized serialization for sophisticated applications.
1. Custom Field Types
from pydantic import BaseModel, validator, Field
from typing import Any, List, Dict
import re
from datetime import datetime
from decimal import Decimal
class PhoneNumber(str):
    """str subclass holding a normalized, digits-only phone number."""

    @classmethod
    def __get_validators__(cls):
        # Pydantic v1 custom-type hook: yield the validator callables.
        yield cls.validate

    @classmethod
    def validate(cls, value):
        """Strip formatting from *value* and enforce a 10-15 digit length."""
        if not isinstance(value, str):
            raise TypeError('Phone number must be a string')
        # Keep digits only; punctuation and spaces are formatting noise.
        digits = re.sub(r'\D', '', value)
        if not 10 <= len(digits) <= 15:
            raise ValueError('Phone number must be 10-15 digits')
        return cls(digits)

    def __repr__(self):
        """US-style "(NNN) NNN-NNNN" for 10-digit numbers; raw digits otherwise."""
        if len(self) != 10:
            return self
        return f"({self[:3]}) {self[3:6]}-{self[6:]}"
class Email(str):
    """str subclass for validated, lowercase-normalized email addresses."""

    @classmethod
    def __get_validators__(cls):
        # Pydantic v1 custom-type hook: yield the validator callables.
        yield cls.validate

    @classmethod
    def validate(cls, value):
        """Check overall shape and RFC-style length limits, then lowercase."""
        if not isinstance(value, str):
            raise TypeError('Email must be a string')
        # Basic email regex; guarantees exactly one '@' in the value.
        if re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', value) is None:
            raise ValueError('Invalid email format')
        # Additional business rules (length caps on each side of the '@').
        local_part, _, domain_part = value.partition('@')
        if len(local_part) > 64:
            raise ValueError('Email local part too long')
        if len(domain_part) > 253:
            raise ValueError('Email domain too long')
        return cls(value.lower())
class Currency(Decimal):
    """Decimal subclass that normalizes monetary amounts to two places.

    Rejects negative amounts and values that cannot be parsed as numbers.
    """

    @classmethod
    def __get_validators__(cls):
        # Pydantic v1 custom-type hook: yield the validator callables.
        yield cls.validate

    @classmethod
    def validate(cls, value):
        """Coerce *value* to a non-negative Currency rounded to cents."""
        # Local import: the module only imports Decimal, and the original
        # referenced `decimal.InvalidOperation` without importing the
        # `decimal` module itself, so bad input raised NameError instead of
        # ValueError (bug fix).
        from decimal import InvalidOperation
        try:
            # str() first so floats round-trip via their short repr rather
            # than their exact binary expansion.
            rounded = Decimal(str(value)).quantize(Decimal('0.01'))
        except (ValueError, InvalidOperation) as err:
            raise ValueError('Invalid currency amount') from err
        # Checked outside the try/except so this specific message is not
        # swallowed and replaced by the generic one above (second bug fix:
        # the original raised inside the try, re-raising generically).
        if rounded < 0:
            raise ValueError('Currency amount cannot be negative')
        return cls(rounded)
# Usage in models
class UserProfile(BaseModel):
    """Profile demonstrating the custom Email/PhoneNumber/Currency types."""

    email: Email
    phone: PhoneNumber
    salary: Currency

    class Config:
        # How each custom type is rendered by pydantic's JSON machinery.
        json_encoders = {
            Currency: float,
            PhoneNumber: str,
            Email: str,
        }
2. Advanced Validation Patterns
from pydantic import BaseModel, root_validator, validator
from typing import Optional, Union, List, Dict, Any
from datetime import datetime, date
import json
class ConditionalValidation(BaseModel):
    """Model with complex conditional validation logic"""
    # Role of the account; drives which optional fields become mandatory.
    user_type: str = Field(..., regex=r'^(admin|user|guest)$')
    permissions: List[str] = Field(default_factory=list)
    access_level: int = Field(..., ge=1, le=10)
    # Conditional fields
    admin_key: Optional[str] = None
    department: Optional[str] = None
    temp_access_expires: Optional[datetime] = None

    @root_validator
    def validate_user_permissions(cls, values):
        """Cross-field checks: role-specific required fields plus a
        permission whitelist derived from access_level.

        NOTE(review): pydantic v1 root_validator — `values` may lack keys
        whose own field validation failed, so access_level could be None
        here; confirm upstream validation always populates it.
        """
        user_type = values.get('user_type')
        permissions = values.get('permissions', [])
        access_level = values.get('access_level')
        if user_type == 'admin':
            # Admin must have admin_key
            if not values.get('admin_key'):
                raise ValueError('admin_key is required for admin users')
            # Admin must have department
            if not values.get('department'):
                raise ValueError('department is required for admin users')
            # Admin access level must be 7+
            if access_level < 7:
                raise ValueError('Admin users must have access level 7 or higher')
        elif user_type == 'guest':
            # Guest must have expiration
            if not values.get('temp_access_expires'):
                raise ValueError('temp_access_expires is required for guest users')
            # Guest access level limited
            if access_level > 3:
                raise ValueError('Guest users cannot have access level above 3')
            # Validate expiration is in future
            # NOTE(review): naive utcnow() comparison — assumes
            # temp_access_expires is a naive UTC datetime; confirm callers.
            expires = values.get('temp_access_expires')
            if expires and expires <= datetime.utcnow():
                raise ValueError('temp_access_expires must be in the future')
        # Validate permissions based on access level
        # Cumulative whitelist: each level inherits all lower levels' perms.
        max_permissions = {
            1: ['read'],
            2: ['read', 'write'],
            3: ['read', 'write', 'delete'],
            4: ['read', 'write', 'delete', 'create'],
            5: ['read', 'write', 'delete', 'create', 'modify'],
        }
        allowed_perms = []
        for level in range(1, access_level + 1):
            allowed_perms.extend(max_permissions.get(level, []))
        # NOTE(review): levels 6-10 add nothing beyond level 5's set —
        # presumably an intentional cap, but worth confirming.
        invalid_perms = set(permissions) - set(allowed_perms)
        if invalid_perms:
            raise ValueError(f'Permissions {invalid_perms} not allowed for access level {access_level}')
        return values
class DynamicModel(BaseModel):
    """Model that validates a free-form payload against a versioned schema."""

    schema_version: str = Field(..., regex=r'^\d+\.\d+$')
    data: Dict[str, Any]

    @validator('data')
    def validate_data_by_schema(cls, v, values):
        """Dispatch payload validation on the declared schema_version."""
        version = values.get('schema_version')
        dispatch = {
            '1.0': cls._validate_v1_schema,
            '2.0': cls._validate_v2_schema,
        }
        if version not in dispatch:
            raise ValueError(f'Unsupported schema version: {version}')
        return dispatch[version](v)

    @classmethod
    def _validate_v1_schema(cls, data):
        """v1 payloads: require `name` and a non-negative integer `age`."""
        for required in ('name', 'age'):
            if required not in data:
                raise ValueError(f'Missing required field: {required}')
        age = data['age']
        if not isinstance(age, int) or age < 0:
            raise ValueError('age must be a positive integer')
        return data

    @classmethod
    def _validate_v2_schema(cls, data):
        """v2 payloads: require full_name, an ISO birth_date, and email."""
        for required in ('full_name', 'birth_date', 'email'):
            if required not in data:
                raise ValueError(f'Missing required field: {required}')
        # Validate birth_date format
        try:
            datetime.strptime(data['birth_date'], '%Y-%m-%d')
        except ValueError:
            raise ValueError('birth_date must be in YYYY-MM-DD format')
        return data
3. Custom Serialization
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import json
from datetime import datetime
from decimal import Decimal
class SerializationModel(BaseModel):
    """Model with custom serialization logic.

    ``dict()`` appends derived display fields; ``json()`` pretty-prints the
    result, encoding datetime/Decimal/set values itself because plain
    ``json.dumps`` does not consult ``Config.json_encoders``.
    """

    id: int
    name: str
    created_at: datetime
    tags: List[str]
    metadata: Dict[str, Any]
    price: Decimal

    class Config:
        # Encoders used by pydantic's own .json() machinery (v1).
        json_encoders = {
            datetime: lambda v: v.isoformat(),
            Decimal: lambda v: float(v),
            set: lambda v: list(v)
        }
        # Field aliases for serialization (pydantic v1 `fields` shorthand)
        fields = {
            'created_at': 'createdAt',
            'metadata': 'meta'
        }

    @staticmethod
    def _json_default(value):
        """json.dumps fallback mirroring Config.json_encoders."""
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, set):
            return list(value)
        raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')

    def dict(self, **kwargs) -> Dict[str, Any]:
        """Custom dict method with additional processing"""
        data = super().dict(**kwargs)
        # Add computed fields
        data['display_name'] = self.name.title()
        data['tag_count'] = len(self.tags)
        # Format price for display
        data['formatted_price'] = f"${self.price:.2f}"
        return data

    def json(self, **kwargs) -> str:
        """Custom JSON serialization.

        BUG FIX: the original fed dict() output (which still contains raw
        datetime/Decimal objects) straight to json.dumps, raising TypeError;
        a default encoder is now supplied.
        """
        return json.dumps(
            self.dict(**kwargs),
            indent=2,
            ensure_ascii=False,
            default=self._json_default,
        )
class APIResponse(BaseModel):
    """Generic API envelope carrying either a payload or a list of errors."""

    success: bool = True
    message: str = "OK"
    data: Optional[Any] = None
    errors: List[str] = Field(default_factory=list)
    timestamp: datetime = Field(default_factory=datetime.utcnow)

    class Config:
        json_encoders = {
            datetime: lambda stamp: stamp.isoformat()
        }

    @classmethod
    def success_response(cls, data: Any = None, message: str = "OK"):
        """Build a successful envelope around *data*."""
        return cls(success=True, message=message, data=data)

    @classmethod
    def error_response(cls, errors: List[str], message: str = "Error"):
        """Build a failure envelope carrying *errors*."""
        return cls(success=False, message=message, errors=errors)
4. Generic Models
from pydantic import BaseModel
from typing import TypeVar, Generic, List, Optional
T = TypeVar('T')
class PaginatedResponse(BaseModel, Generic[T]):
    """Generic single page of results plus paging metadata."""

    items: List[T]
    total: int
    page: int = Field(..., ge=1)
    per_page: int = Field(..., ge=1, le=100)
    has_next: bool
    has_prev: bool

    @classmethod
    def create(cls, items: List[T], total: int, page: int, per_page: int):
        """Build a page, deriving has_next/has_prev from the counts."""
        seen_so_far = page * per_page
        return cls(
            items=items,
            total=total,
            page=page,
            per_page=per_page,
            has_next=seen_so_far < total,
            has_prev=page > 1,
        )
class CacheWrapper(BaseModel, Generic[T]):
    """Generic cache entry: a payload plus caching metadata."""

    data: T
    cached_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: Optional[datetime] = None
    cache_key: str
    hit_count: int = 0

    def is_expired(self) -> bool:
        """True once expires_at has passed; entries without one never expire."""
        return self.expires_at is not None and datetime.utcnow() > self.expires_at

    def increment_hit(self):
        """Record one cache hit."""
        self.hit_count += 1
# Usage with specific types
# NOTE(review): `User` and `Product` are defined further down this guide
# (section 7); in a single runnable module these aliases would need those
# classes defined first — verify intended execution order.
UserPage = PaginatedResponse[User]
ProductCache = CacheWrapper[Product]
5. Performance Optimization
from functools import cached_property
from typing import Any, Dict, List, Set

import orjson

from pydantic import BaseModel, Field, validator
from pydantic.dataclasses import dataclass
# Use pydantic dataclasses for better performance when BaseModel features
# aren't needed.
class _FastDataClassConfig:
    """orjson-backed (de)serialization config for FastDataClass."""
    # NOTE(review): orjson.dumps returns bytes, not str — callers that
    # expect str output must decode; kept as in the original.
    json_loads = orjson.loads
    json_dumps = orjson.dumps


@dataclass(config=_FastDataClassConfig)
class FastDataClass:
    """Pydantic dataclass for better performance.

    BUG FIX: the original nested a ``class Config`` inside the dataclass
    body; pydantic dataclasses only read configuration passed via the
    decorator's ``config=`` argument, so that Config was silently ignored.
    """
    id: int
    name: str
    tags: List[str] = Field(default_factory=list)
class OptimizedModel(BaseModel):
    """Model tuned for speed in trusted environments (pydantic v1 Config)."""

    id: int
    data: Dict[str, Any]

    class Config:
        # Skip re-validation when attributes are assigned after construction.
        validate_assignment = False
        # orjson for faster (de)serialization; orjson.dumps returns bytes,
        # so decode to match the str return pydantic expects.
        json_loads = orjson.loads
        json_dumps = lambda v, *, default: orjson.dumps(v, default=default).decode()
        # Keep cached_property descriptors out of field collection.
        # BUG FIX: the original referenced cached_property without importing
        # it anywhere (NameError at class creation); now imported from
        # functools at the top of this section.
        keep_untouched = (cached_property,)
        # NOTE(review): 'allow_reuse' is a @validator keyword, not a Config
        # option; pydantic v1 ignores unknown Config attributes. Kept for
        # fidelity with the original, but it has no effect here.
        allow_reuse = True
# Validation caching for expensive operations.
# NOTE(review): module-level and unbounded — entries are never evicted.
validation_cache: Dict[str, Any] = {}


class CachedValidationModel(BaseModel):
    """Model whose email validation memoizes its (expensive) result."""

    email: str
    complex_data: Dict[str, Any]

    @validator('email')
    def validate_email_cached(cls, v):
        """Validate email, consulting the shared cache before the slow path."""
        key = f"email_validation:{v}"
        if key in validation_cache:
            if validation_cache[key]:
                return v
            raise ValueError('Invalid email (cached)')
        # Expensive validation (e.g. an external API call); cache the
        # verdict either way so repeated values short-circuit above.
        verdict = cls._expensive_email_validation(v)
        validation_cache[key] = verdict
        if not verdict:
            raise ValueError('Invalid email')
        return v

    @classmethod
    def _expensive_email_validation(cls, email: str) -> bool:
        """Simulated slow check: has an '@' and a dot after the first '@'."""
        import time
        time.sleep(0.1)  # Simulate API call
        if '@' not in email:
            return False
        return '.' in email.split('@')[1]
6. Advanced Configuration
from pydantic import BaseModel, Field, validator
from typing import Any, Dict, List
import os
# BaseSettings (not BaseModel) is what performs environment loading in
# pydantic v1; imported here so this section stays self-contained.
from pydantic import BaseSettings


class AppConfig(BaseSettings):
    """Application configuration with environment variable support.

    BUG FIX: the original subclassed BaseModel, which ignores ``env=`` on
    Field, ``env_file``, and settings-source customisation entirely — only
    BaseSettings reads the environment. ``customise_sources`` also moved
    inside ``Config``, where pydantic v1 looks for it.
    """

    # Database settings
    database_url: str = Field(..., env='DATABASE_URL')
    database_pool_size: int = Field(10, env='DB_POOL_SIZE')
    # Redis settings
    redis_url: str = Field('redis://localhost:6379', env='REDIS_URL')
    redis_ttl: int = Field(3600, env='REDIS_TTL')
    # API settings
    api_key: str = Field(..., env='API_KEY')
    rate_limit: int = Field(100, env='RATE_LIMIT')
    # Feature flags
    feature_flags: Dict[str, bool] = Field(default_factory=dict)

    @validator('database_url')
    def validate_database_url(cls, v):
        """Accept only postgresql/mysql/sqlite URL schemes."""
        if not v.startswith(('postgresql://', 'mysql://', 'sqlite://')):
            raise ValueError('Invalid database URL scheme')
        return v

    @validator('feature_flags', pre=True)
    def parse_feature_flags(cls, v):
        """Parse 'a=true,b=false' strings (e.g. from an env var) into a dict."""
        if isinstance(v, str):
            # Parse from environment variable
            flags = {}
            for flag in v.split(','):
                if '=' in flag:
                    key, value = flag.split('=', 1)
                    flags[key.strip()] = value.strip().lower() == 'true'
            return flags
        return v or {}

    class Config:
        env_file = '.env'
        env_file_encoding = 'utf-8'
        case_sensitive = False

        @classmethod
        def customise_sources(
            cls,
            init_settings,
            env_settings,
            file_secret_settings,
        ):
            # Custom priority: env vars > secrets > init values
            return (
                env_settings,
                file_secret_settings,
                init_settings,
            )
# Global configuration instance
# NOTE(review): instantiating at import time raises a validation error when
# required settings (database_url, api_key) are absent — confirm this
# fail-fast-at-import behavior is intended.
config = AppConfig()
7. Model Composition and Mixins
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
from abc import ABC, abstractmethod
class TimestampMixin(BaseModel):
    """Mixin adding creation/update timestamps."""

    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = None

    def touch(self):
        """Stamp updated_at with the current UTC time."""
        self.updated_at = datetime.utcnow()
class AuditMixin(BaseModel):
    """Mixin adding who-changed-it audit fields and a version counter."""

    created_by: Optional[str] = None
    updated_by: Optional[str] = None
    version: int = Field(default=1, ge=1)

    def increment_version(self, updated_by: str):
        """Bump the version and record who made the change."""
        self.version = self.version + 1
        self.updated_by = updated_by
class SoftDeleteMixin(BaseModel):
    """Mixin adding soft-delete bookkeeping."""

    deleted_at: Optional[datetime] = None
    deleted_by: Optional[str] = None

    def soft_delete(self, deleted_by: str):
        """Flag the record as deleted now, recording who deleted it."""
        self.deleted_at = datetime.utcnow()
        self.deleted_by = deleted_by

    @property
    def is_deleted(self) -> bool:
        """Whether soft_delete has been applied (deleted_at set)."""
        return self.deleted_at is not None
class BaseEntity(TimestampMixin, AuditMixin, SoftDeleteMixin):
    """Base entity with common fields"""
    # Combines the timestamp, audit, and soft-delete mixins; id is None
    # until assigned (presumably by a persistence layer — confirm).
    id: Optional[int] = None
class User(BaseEntity):
    """User model with all mixins"""
    # Inherits timestamps, audit trail, and soft delete from BaseEntity.
    email: str
    name: str
    is_active: bool = True
class Product(BaseEntity):
    """Product model with all mixins"""
    # price uses Decimal (imported at the top of the guide) for exactness.
    name: str
    price: Decimal
    category: str
    tags: List[str] = Field(default_factory=list)
8. Custom Validators and Constraints
from pydantic import BaseModel, validator, constr, conint
from typing import List, Pattern
import re
class AdvancedValidationModel(BaseModel):
    """Model with advanced validation patterns"""

    # Constrained types
    username: constr(regex=r'^[a-zA-Z0-9_]{3,20}$', min_length=3, max_length=20)
    password: constr(min_length=8)
    age: conint(ge=0, le=150)
    # Custom validation with dependencies
    password_confirm: str
    security_questions: List[str] = Field(..., min_items=2, max_items=5)

    @validator('password')
    def validate_password_strength(cls, v):
        """Require length >= 8 and at least three of four character classes."""
        if len(v) < 8:
            raise ValueError('Password must be at least 8 characters')
        character_classes = (
            (r'[a-z]', 'lowercase letter'),
            (r'[A-Z]', 'uppercase letter'),
            (r'\d', 'digit'),
            (r'[!@#$%^&*(),.?":{}|<>]', 'special character'),
        )
        missing = [label for pattern, label in character_classes
                   if not re.search(pattern, v)]
        if len(missing) > 1:  # Allow missing one type
            raise ValueError(f'Password must contain: {", ".join(missing)}')
        return v

    @validator('password_confirm')
    def passwords_match(cls, v, values):
        """Reject when the confirmation differs from an accepted password."""
        if 'password' in values and v != values['password']:
            raise ValueError('Passwords do not match')
        return v

    @validator('security_questions', each_item=True)
    def validate_security_question(cls, v):
        """Each question: trimmed, at least 10 chars, and ends with '?'."""
        question = v.strip()
        if len(question) < 10:
            raise ValueError('Security questions must be at least 10 characters')
        if not question.endswith('?'):
            raise ValueError('Security questions must end with a question mark')
        return question
Best Practices Summary
1. Performance Optimization
- Use dataclasses for simple data structures
- Cache expensive validations
- Use orjson for faster JSON operations
- Disable unnecessary features in production
2. Custom Types
- Create reusable custom types for common patterns
- Implement proper validation in custom types
- Provide meaningful error messages
3. Model Composition
- Use mixins for common functionality
- Create base classes for shared behavior
- Keep models focused and cohesive
4. Validation Strategy
- Use field validators for simple checks
- Use root validators for complex cross-field validation
- Cache validation results for expensive operations