State management and caching tools: keep application state across reruns and avoid repeating expensive work by caching computed data and shared resources.
Session state: persistent storage that survives user interactions and app reruns.
# Session state proxy object: attribute- and dictionary-style access to values
# that persist across user interactions and app reruns.
session_state: SessionStateProxy  # type-only declaration; no value is assigned here
# Access patterns:
# st.session_state.key = value      # Set value
# value = st.session_state.key      # Get value
# st.session_state["key"] = value   # Dictionary-style set
# value = st.session_state["key"]   # Dictionary-style get
# del st.session_state.key          # Delete key
# "key" in st.session_state         # Check existence
# st.session_state.clear()          # Clear all state

# URL query parameter handling for shareable application state.
# Query parameters proxy object: reads and writes the URL query parameters, so
# the state stored here travels with the URL when it is shared.
query_params: QueryParamsProxy  # type-only declaration; no value is assigned here
# Access patterns:
# st.query_params.key = value       # Set parameter
# value = st.query_params.key       # Get parameter
# st.query_params["key"] = value    # Dictionary-style set
# value = st.query_params["key"]    # Dictionary-style get
# del st.query_params.key           # Delete parameter
# "key" in st.query_params          # Check existence
# st.query_params.clear()           # Clear all parameters
# dict(st.query_params)             # Convert to dictionary

# Cache function execution results to improve performance and reduce redundant computations.
def cache_data(func=None, *, ttl=None, max_entries=None, show_spinner=True, persist=None, experimental_allow_widgets=False, hash_funcs=None, max_entries_per_session=None, validate=None):
    """
    Function decorator to cache the result of function calls.

    Usable bare (``@st.cache_data``) or with keyword options
    (``@st.cache_data(ttl=300)``); ``func`` is filled in automatically when it
    is used as a bare decorator.

    NOTE: this is a signature/documentation stub -- the body is only this
    docstring, so calling it here returns None.

    Parameters:
    - func (callable or None): Function to cache (auto-filled when used as decorator)
    - ttl (float or None): Time to live in seconds (None for no expiration)
    - max_entries (int or None): Maximum number of cached entries (None for unlimited)
    - show_spinner (bool or str): Show a spinner while the function executes;
      the examples also pass a message string here
    - persist (bool or None): Persist cache across app restarts (experimental)
    - experimental_allow_widgets (bool): Allow widgets inside cached function
    - hash_funcs (dict or None): Maps a parameter type to a function that returns
      a hashable stand-in for values of that type (e.g. ``{pd.DataFrame: fn}``)
    - max_entries_per_session (int or None): Maximum entries per session.
      NOTE(review): no example exercises this; confirm the installed Streamlit
      version accepts it.
    - validate (callable or None): Called with a cached value; presumably a
      falsy result discards it and recomputes. NOTE(review): this option is also
      listed for ``cache_resource``; confirm the installed Streamlit version
      accepts it on ``cache_data`` before relying on it.

    Returns:
    callable: Decorated function with caching capability
    """


# Cache global resources like database connections, models, and expensive objects.
def cache_resource(func=None, *, ttl=None, max_entries=None, show_spinner=True, validate=None, experimental_allow_widgets=False, hash_funcs=None):
    """
    Function decorator to cache global resources.

    Intended for global resources such as database connections, models and
    other expensive objects. Usable bare (``@st.cache_resource``); ``func`` is
    filled in automatically in that case.

    NOTE: this is a signature/documentation stub -- the body is only this
    docstring, so calling it here returns None.

    Parameters:
    - func (callable or None): Function to cache (auto-filled when used as decorator)
    - ttl (float or None): Time to live in seconds (None for no expiration)
    - max_entries (int or None): Maximum number of cached resources (None for unlimited)
    - show_spinner (bool): Show spinner while function executes
    - validate (callable or None): Called with a cached resource; presumably a
      falsy result discards it and recomputes -- confirm against the Streamlit docs
    - experimental_allow_widgets (bool): Allow widgets inside cached function
    - hash_funcs (dict or None): Custom hash functions for parameters

    Returns:
    callable: Decorated function with resource caching capability
    """


# Deprecated caching decorator for backward compatibility.
def cache(func=None, **kwargs):
    """
    Legacy caching decorator (deprecated).

    Note: This function is deprecated. Use st.cache_data (for computed data) or
    st.cache_resource (for global resources) instead.

    NOTE: this is a signature/documentation stub -- the body is only this
    docstring, so calling it here returns None.

    Parameters:
    - func (callable or None): Function to cache
    - **kwargs: Cache configuration options

    Returns:
    callable: Decorated function with caching
    """


import streamlit as st
# Initialize session state keys that are not set yet.
if 'count' not in st.session_state:
    st.session_state.count = 0
if 'user_name' not in st.session_state:
    st.session_state.user_name = ''


def _change_count(delta):
    """Button callback: add *delta* to the stored counter."""
    st.session_state.count += delta


def _reset_count():
    """Button callback: set the stored counter back to zero."""
    st.session_state.count = 0


# Display current state. All mutations below happen in on_click callbacks or via
# a keyed widget, which Streamlit applies before re-running the script, so these
# values are already current. (Mutating inside `if st.button(...):` after this
# point would leave the display one interaction behind.)
st.write(f"Count: {st.session_state.count}")
st.write(f"User: {st.session_state.user_name}")

# Modify state with buttons
col1, col2, col3 = st.columns(3)
with col1:
    st.button('Increment', on_click=_change_count, args=(1,))
with col2:
    st.button('Decrement', on_click=_change_count, args=(-1,))
with col3:
    st.button('Reset', on_click=_reset_count)

# Update user name: key='user_name' binds the widget directly to
# st.session_state.user_name (initialized above, so no `value=` is passed).
st.text_input('Enter your name:', key='user_name')

# Session state with complex data structures
if 'shopping_cart' not in st.session_state:
    st.session_state.shopping_cart = []
if 'user_preferences' not in st.session_state:
    st.session_state.user_preferences = {
        'theme': 'light',
        'notifications': True,
        'language': 'en'
    }


# Add item to cart
def add_to_cart(item, price):
    """Append *item* at *price* (quantity 1) to the session's shopping cart."""
    st.session_state.shopping_cart.append({
        'item': item,
        'price': price,
        'quantity': 1
    })


def _remove_from_cart(index):
    """Button callback: drop the cart entry at *index*."""
    st.session_state.shopping_cart.pop(index)


def _add_from_inputs():
    """Button callback: add the item described by the input widgets, if it has a name."""
    name = st.session_state.new_item_name
    if name:
        add_to_cart(name, st.session_state.new_item_price)
        # Shown once by the script body below, then deleted.
        st.session_state.cart_notice = f"Added {name} to cart!"


# Display cart. Adds and removals run as on_click callbacks, which Streamlit
# executes before re-running the script, so this listing is never one click behind.
st.subheader("Shopping Cart")
if st.session_state.shopping_cart:
    for i, entry in enumerate(st.session_state.shopping_cart):
        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            st.write(entry['item'])
        with col2:
            st.write(f"${entry['price']:.2f}")
        with col3:
            st.button('Remove', key=f'remove_{i}', on_click=_remove_from_cart, args=(i,))
    # Weight by quantity so the total stays correct if quantities ever exceed 1.
    total = sum(entry['price'] * entry['quantity'] for entry in st.session_state.shopping_cart)
    st.write(f"**Total: ${total:.2f}**")
else:
    st.write("Cart is empty")

# Add new items (widget values are read from session state by the callback)
st.subheader("Add Items")
item_name = st.text_input("Item name:", key="new_item_name")
item_price = st.number_input("Price:", min_value=0.01, format="%.2f", key="new_item_price")
st.button("Add to Cart", on_click=_add_from_inputs)
if 'cart_notice' in st.session_state:
    st.success(st.session_state.cart_notice)
    del st.session_state.cart_notice

# Initialize from query parameters
if 'page' not in st.query_params:
    st.query_params.page = 'home'
if 'filter' not in st.query_params:
    st.query_params.filter = 'all'
# Navigation that updates URL
st.sidebar.title("Navigation")
pages = ['home', 'analytics', 'settings']
current_page = st.sidebar.selectbox(
    "Select page:",
    pages,
    # A page value in the URL that is not a known page falls back to the first option.
    index=pages.index(st.query_params.page) if st.query_params.page in pages else 0
)
# Update query params when page changes
if current_page != st.query_params.page:
    st.query_params.page = current_page
    st.rerun()
# Filters that update URL
filters = ['all', 'active', 'inactive']
current_filter = st.sidebar.selectbox(
    "Filter:",
    filters,
    index=filters.index(st.query_params.filter) if st.query_params.filter in filters else 0
)
if current_filter != st.query_params.filter:
    st.query_params.filter = current_filter
    st.rerun()
# Display content based on query params
st.title(f"Page: {st.query_params.page}")
st.write(f"Filter: {st.query_params.filter}")
# Show shareable URL. Interpolating the st.query_params proxy itself does not
# produce a URL, so encode its current contents as a query string that can be
# appended to the app's address.
from urllib.parse import urlencode

st.info(f"Share this URL: ?{urlencode(dict(st.query_params))}")

import pandas as pd
import io
import time
import requests


@st.cache_data
def load_data(url):
    """Download the CSV at *url* and parse it into a DataFrame (cached per URL).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    st.info(f"Loading data from {url}...")
    # A timeout keeps a stalled server from hanging the script run indefinitely.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as CSV
    # (which could otherwise be cached as if it were real data).
    response.raise_for_status()
    # read_csv needs a path or file-like object; raw bytes are rejected.
    return pd.read_csv(io.BytesIO(response.content))


@st.cache_data(ttl=300)  # Cache for 5 minutes
def expensive_computation(n):
    """Return the sum of squares of 0..n-1 (simulated expensive work, cached with a TTL)."""
    st.info("Running expensive computation...")
    time.sleep(2)  # Simulate processing time
    return sum(i**2 for i in range(n))


@st.cache_data(max_entries=10)
def process_data(data, operation):
    """Apply 'sum', 'mean' or 'max' to *data*; any other operation returns *data* unchanged.

    The cache keeps at most 10 entries.
    """
    st.info(f"Processing data with {operation}...")
    if operation == 'sum':
        return data.sum()
    elif operation == 'mean':
        return data.mean()
    elif operation == 'max':
        return data.max()
    else:
        return data


# Use cached functions
st.subheader("Cached Data Loading")
# This will only run once per URL
try:
    df = load_data("https://raw.githubusercontent.com/datasets/iris/master/data/iris.csv")
    st.dataframe(df.head())
except Exception as e:
    # Deliberately broad: any download or parse failure is reported in the UI.
    st.error(f"Failed to load data: {e}")
# Expensive computation with TTL
st.subheader("Cached Computation")
n = st.slider("Computation size:", 1000, 10000, 5000)
result = expensive_computation(n)
st.write(f"Result: {result}")
# Data processing with limited cache; `df` is only bound if loading succeeded.
if 'df' in locals():
    operation = st.selectbox("Operation:", ['sum', 'mean', 'max'])
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) > 0:
        processed = process_data(df[numeric_cols[0]], operation)
        st.write(f"Result: {processed}")

import sqlite3
from datetime import datetime


@st.cache_resource
def init_database():
    """Create an in-memory SQLite database holding a sample `users` table (cached resource).

    Returns:
        The sqlite3 connection; every cache hit returns this same connection object.
    """
    st.info("Initializing database connection...")
    # The cached connection is reused by later script runs, which Streamlit may
    # execute on other threads; sqlite3's default same-thread check would then
    # raise ProgrammingError. This example only reads after initialization.
    conn = sqlite3.connect(':memory:', check_same_thread=False)
    # Create sample table
    conn.execute('''
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            name TEXT,
            email TEXT,
            created_at TIMESTAMP
        )
    ''')
    # Insert sample data
    sample_users = [
        ('Alice', 'alice@example.com'),
        ('Bob', 'bob@example.com'),
        ('Charlie', 'charlie@example.com')
    ]
    # Store timestamps as explicit ISO-8601 text (the format the implicit adapter
    # produced); sqlite3's default datetime adapter is deprecated since Python 3.12.
    conn.executemany(
        'INSERT INTO users (name, email, created_at) VALUES (?, ?, ?)',
        [(name, email, datetime.now().isoformat(sep=' ')) for name, email in sample_users],
    )
    conn.commit()
    return conn


@st.cache_resource
def load_model():
    """Return a simulated ML model description (cached resource)."""
    st.info("Loading machine learning model...")
    time.sleep(1)  # Simulate model loading time
    return {"model_type": "random_forest", "accuracy": 0.95, "loaded_at": datetime.now()}


# Use cached resources
st.subheader("Cached Database Connection")
db = init_database()
# Query database: select the displayed columns explicitly instead of indexing SELECT * rows.
users = db.execute('SELECT name, email FROM users').fetchall()
st.write("Users in database:")
for user_name, user_email in users:
    st.write(f"- {user_name} ({user_email})")
st.subheader("Cached Model")
model = load_model()
st.json(model)

# Cache inspection and management
st.subheader("Cache Management")
col1, col2, col3 = st.columns(3)
with col1:
    if st.button("Clear Data Cache"):
        st.cache_data.clear()
        st.success("Data cache cleared!")
with col2:
    if st.button("Clear Resource Cache"):
        st.cache_resource.clear()
        st.success("Resource cache cleared!")
with col3:
    if st.button("Clear All Caches"):
        st.cache_data.clear()
        st.cache_resource.clear()
        st.success("All caches cleared!")


def _dataframe_fingerprint(frame):
    """Return a hashable summary of *frame* for use as a cache_data hash function.

    Combines shape, column labels and per-row content hashes. Hashing by shape
    alone would make two different DataFrames of equal shape share one cache entry
    and return each other's results. Cost grows with the frame's size.
    """
    return (
        frame.shape,
        tuple(frame.columns),
        pd.util.hash_pandas_object(frame, index=True).to_numpy().tobytes(),
    )


# Custom hash function example
@st.cache_data(
    hash_funcs={pd.DataFrame: _dataframe_fingerprint},
    show_spinner="Processing dataframe..."
)
def analyze_dataframe(df, analysis_type):
    """Analyze *df*: "describe" -> df.describe(), "info" -> a shape/columns summary,
    anything else -> the string "Unknown analysis type"."""
    time.sleep(1)  # Simulate analysis
    if analysis_type == "describe":
        return df.describe()
    elif analysis_type == "info":
        return f"Shape: {df.shape}, Columns: {list(df.columns)}"
    else:
        return "Unknown analysis type"


# Use function with custom hashing (`df` is only bound if the earlier data load succeeded)
if 'df' in locals():
    analysis = st.selectbox("Analysis type:", ["describe", "info"])
    result = analyze_dataframe(df, analysis)
    st.write(result)

# Multi-step form with state persistence
if 'form_step' not in st.session_state:
    st.session_state.form_step = 1
if 'form_data' not in st.session_state:
    st.session_state.form_data = {}
st.subheader(f"Multi-step Form - Step {st.session_state.form_step}/3")
# Progress bar (fraction of the three steps reached)
progress = st.session_state.form_step / 3
st.progress(progress)
if st.session_state.form_step == 1:
    st.write("**Personal Information**")
    name = st.text_input("Name:", value=st.session_state.form_data.get('name', ''))
    email = st.text_input("Email:", value=st.session_state.form_data.get('email', ''))
    if st.button("Next"):
        if name and email:
            st.session_state.form_data.update({'name': name, 'email': email})
            st.session_state.form_step = 2
            st.rerun()
        else:
            # Explain why the click did nothing instead of ignoring it silently.
            st.warning("Please enter both your name and email.")
elif st.session_state.form_step == 2:
    st.write("**Preferences**")
    theme = st.selectbox("Theme:", ["Light", "Dark"],
                         index=["Light", "Dark"].index(st.session_state.form_data.get('theme', 'Light')))
    notifications = st.checkbox("Enable notifications",
                                value=st.session_state.form_data.get('notifications', False))
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Back"):
            st.session_state.form_step = 1
            st.rerun()
    with col2:
        if st.button("Next"):
            st.session_state.form_data.update({'theme': theme, 'notifications': notifications})
            st.session_state.form_step = 3
            st.rerun()
elif st.session_state.form_step == 3:
    st.write("**Review & Submit**")
    st.json(st.session_state.form_data)
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Back"):
            st.session_state.form_step = 2
            st.rerun()
    with col2:
        if st.button("Submit"):
            st.success("Form submitted successfully!")
            st.balloons()
            # Reset form. There is no rerun here, so this run still shows the
            # success message; the next interaction starts again at step 1.
            st.session_state.form_step = 1
            st.session_state.form_data = {}


# Cached API call with validation. `validate` is an st.cache_resource option;
# st.cache_data does not accept it, so st.cache_data(validate=...) raises a
# TypeError as soon as the decorator runs. cache_resource hands every caller the
# same list object, so treat the result as read-only.
@st.cache_resource(
    validate=lambda cached_value: len(cached_value) > 0,
    ttl=60
)
def get_api_data(endpoint):
    """Fetch API data (simulated); an empty result fails validation and is recomputed."""
    # Simulate API call
    time.sleep(0.5)
    if endpoint == "valid":
        return ["item1", "item2", "item3"]
    else:
        return []  # Fails validation, so it is never served from the cache


# Use validated cache
endpoint = st.selectbox("API Endpoint:", ["valid", "empty"])
data = get_api_data(endpoint)
if data:
    st.success(f"Got {len(data)} items from cache/API")
    st.write(data)
else:
    st.warning("No data available (cache invalidated)")