Provider package apache-airflow-providers-snowflake for Apache Airflow
Helper functions for parameter formatting, authentication token management, OpenLineage integration, and Snowpark session injection. These utilities provide essential support functionality for secure and efficient Snowflake operations within Airflow workflows.
Security utility for safely formatting SQL parameters to prevent injection attacks and ensure proper parameter handling in Snowflake queries.
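The quoting rule is simple enough to sketch directly. This is equivalent logic for illustration, matching the documented behavior, not the provider's source:

```python
def enclose_param(param: str) -> str:
    # Double any embedded single quotes, then wrap the whole value in single quotes
    return "'" + param.replace("'", "''") + "'"

print(enclose_param("Today's sales"))  # 'Today''s sales'
```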
def enclose_param(param: str) -> str:
"""
Replace each single quote in the parameter with two single quotes and enclose the result in single quotes.
Provides SQL injection protection for Snowflake string parameters.
Parameters:
- param: String parameter to be safely enclosed
Returns:
Safely enclosed parameter string suitable for SQL inclusion
Examples:
- enclose_param("without quotes") returns "'without quotes'"
- enclose_param("'with quotes'") returns "'''with quotes'''"
- enclose_param("Today's sales") returns "'Today''s sales'"
"""Utility function for intelligent session injection based on function signatures, enabling automatic Snowpark session management in callable functions.
def inject_session_into_op_kwargs(
python_callable: Callable,
op_kwargs: dict,
session: Session | None
) -> dict:
"""
Inject Snowpark session into operator kwargs based on callable signature.
Automatically provides session parameter if the callable function defines it.
Parameters:
- python_callable: Python function to analyze for session parameter
- op_kwargs: Current operator keyword arguments
- session: Snowpark Session object to inject
Returns:
Updated kwargs dictionary with session injected if needed
Behavior:
- Inspects callable signature for 'session' parameter
- Injects session only if parameter exists in signature
- Preserves existing kwargs without modification
"""Comprehensive JWT token management for Snowflake SQL API authentication with automatic renewal and secure key handling.
class JWTGenerator:
"""
Creates and signs JWTs for Snowflake SQL API authentication.
Provides automatic token renewal and secure key management.
"""
LIFETIME = timedelta(minutes=59) # Default token lifetime
RENEWAL_DELTA = timedelta(minutes=54) # Default token renewal delay
ALGORITHM = "RS256" # JWT signing algorithm
def __init__(
self,
account: str,
user: str,
private_key: Any,
lifetime: timedelta = LIFETIME,
renewal_delay: timedelta = RENEWAL_DELTA,
):
"""
Initialize JWT generator for Snowflake authentication.
Parameters:
- account: Snowflake account identifier
- user: Snowflake username
- private_key: RSA private key for JWT signing
- lifetime: JWT token lifetime duration
- renewal_delay: Time before expiration to renew token
"""def get_token(self) -> str | None:
"""
Generate new JWT token or return cached valid token.
Automatically handles token renewal based on expiration timing.
Returns:
Valid JWT token string or None if generation fails
Behavior:
- Returns cached token if still valid
- Generates new token if expired or near expiration
- Handles automatic renewal based on renewal_delay
"""
@staticmethod
def prepare_account_name_for_jwt(raw_account: str) -> str:
"""
Prepare account identifier for JWT token generation.
Formats account name according to Snowflake requirements.
Parameters:
- raw_account: Raw account identifier string
Returns:
Properly formatted account identifier for JWT claims
"""
@staticmethod
def calculate_public_key_fingerprint(private_key: Any) -> str:
"""
Calculate public key fingerprint from private key.
Used for JWT key identification in Snowflake authentication.
Parameters:
- private_key: RSA private key object
Returns:
SHA256 fingerprint of the public key in required format
"""# JWT claim field names
ISSUER = "iss" # JWT issuer field name
EXPIRE_TIME = "exp" # JWT expiration field name
ISSUE_TIME = "iat" # JWT issue time field name
SUBJECT = "sub" # JWT subject field nameComprehensive OpenLineage integration utilities for data lineage tracking, providing metadata extraction and event generation for Snowflake operations.
def fix_account_name(name: str) -> str:
"""
Fix account name to proper OpenLineage format.
Converts account identifier to standard format: <account_id>.<region>.<cloud>
Parameters:
- name: Raw Snowflake account identifier
Returns:
Properly formatted account name for OpenLineage
"""
def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
"""
Fix Snowflake SQLAlchemy URI to OpenLineage structure.
Standardizes connection URI format for lineage tracking.
Parameters:
- uri: Raw SQLAlchemy connection URI
Returns:
Standardized URI string for OpenLineage integration
"""def emit_openlineage_events_for_snowflake_queries(
hook: Union[SnowflakeHook, SnowflakeSqlApiHook],
task_instance: TaskInstance,
query_ids: list[str],
sql: str
) -> None:
"""
Emit OpenLineage events for executed Snowflake queries.
Generates comprehensive lineage metadata for data operations.
Parameters:
- hook: Snowflake hook instance (regular or SQL API)
- task_instance: Airflow TaskInstance for context
- query_ids: List of Snowflake query IDs to track
- sql: SQL statements that were executed
Behavior:
- Extracts table and column lineage from query results
- Generates start and complete OpenLineage events
- Includes comprehensive metadata about data transformations
"""Utilities for handling Airflow version compatibility and ensuring consistent behavior across different Airflow versions.
# Version detection constants
AIRFLOW_V_3_0_PLUS: bool # Whether running Airflow 3.0 or later
# Compatible base operator import
BaseOperator # Imports from airflow.sdk for 3.0+, airflow.models for <3.0
def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"""
Get the base Airflow version as a tuple of (major, minor, micro).
Provides version information for compatibility checks.
Returns:
Tuple of (major, minor, micro) version numbers
Example:
- For Airflow 2.10.1 returns (2, 10, 1)
- For Airflow 3.0.0 returns (3, 0, 0)
"""from airflow.providers.snowflake.utils.common import enclose_param
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
def safe_query_execution(customer_name: str, region: str):
"""Execute query with safely enclosed parameters."""
hook = SnowflakeHook(snowflake_conn_id='snowflake_prod')
# Safely enclose user-provided parameters
safe_customer = enclose_param(customer_name)
safe_region = enclose_param(region)
# Build secure SQL query
sql = f"""
SELECT *
FROM customers
WHERE customer_name = {safe_customer}
AND region = {safe_region}
"""
result = hook.run(sql, handler=lambda cursor: cursor.fetchall())
return result
# Example usage
customers = safe_query_execution("O'Reilly Industries", "North America")
from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs
from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
def my_data_processing_function(data_param: str, session=None, **kwargs):
"""
Data processing function that optionally uses Snowpark session.
Session will be automatically injected if available.
"""
if session:
# Use Snowpark DataFrame API
df = session.table("raw_data")
return df.filter(df.col("category") == data_param).count()
else:
# Fallback to regular processing
print(f"Processing {data_param} without Snowpark")
return 0
def create_dynamic_snowpark_task(python_callable, **op_kwargs):
"""Create Snowpark task with intelligent session injection."""
# This would be called internally by SnowparkOperator
session = get_snowpark_session() # Hypothetical session creation
# Inject session only if function signature requires it
enhanced_kwargs = inject_session_into_op_kwargs(
python_callable=python_callable,
op_kwargs=op_kwargs,
session=session
)
return python_callable(**enhanced_kwargs)
from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator
from datetime import timedelta
import os
def setup_jwt_authentication():
"""Set up JWT authentication for Snowflake SQL API."""
# Load private key from secure storage
private_key_path = os.getenv('SNOWFLAKE_PRIVATE_KEY_PATH')
with open(private_key_path, 'rb') as key_file:
private_key = key_file.read()
# Initialize JWT generator with custom lifetime
jwt_gen = JWTGenerator(
account='mycompany.us-east-1',
user='airflow_service',
private_key=private_key,
lifetime=timedelta(hours=1), # 1-hour tokens
renewal_delay=timedelta(minutes=50) # Renew 10 minutes before expiry
)
# Get token (will be cached and automatically renewed)
token = jwt_gen.get_token()
return token
def make_authenticated_api_call(query: str):
"""Make Snowflake SQL API call with JWT authentication."""
token = setup_jwt_authentication()
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json',
'X-Snowflake-Authorization-Token-Type': 'KEYPAIR_JWT'
}
# Make API call with authenticated headers
# ... API call implementation
from airflow.providers.snowflake.utils.openlineage import (
emit_openlineage_events_for_snowflake_queries,
fix_account_name,
fix_snowflake_sqlalchemy_uri
)
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
def execute_with_lineage_tracking(task_instance, sql_statements: str):
"""Execute SQL with comprehensive lineage tracking."""
# Initialize hook with proper account format
hook = SnowflakeSqlApiHook(snowflake_conn_id='snowflake_prod')
# Execute queries and get query IDs
query_ids = hook.execute_query(
sql=sql_statements,
statement_count=2,
query_tag='lineage_tracked_operation'
)
# Wait for completion
for query_id in query_ids:
hook.wait_for_query(query_id, raise_error=True)
# Emit OpenLineage events for data lineage
emit_openlineage_events_for_snowflake_queries(
hook=hook,
task_instance=task_instance,
query_ids=query_ids,
sql=sql_statements
)
return query_ids
# Usage in custom operator
class LineageAwareSnowflakeOperator(SnowflakeSqlApiOperator):
"""Custom operator with automatic lineage tracking."""
def execute(self, context):
# Execute queries
result = super().execute(context)
# Add lineage tracking
if hasattr(self, '_query_ids') and self._query_ids:
emit_openlineage_events_for_snowflake_queries(
hook=self.hook,
task_instance=context['task_instance'],
query_ids=self._query_ids,
sql=self.sql
)
return result
from airflow.providers.snowflake.version_compat import (
AIRFLOW_V_3_0_PLUS,
BaseOperator,
get_base_airflow_version_tuple
)
def create_compatible_operator():
"""Create operator compatible with multiple Airflow versions."""
version_tuple = get_base_airflow_version_tuple()
if AIRFLOW_V_3_0_PLUS:
# Use Airflow 3.0+ features
print(f"Running on Airflow {version_tuple}, using SDK features")
# BaseOperator imported from airflow.sdk
class ModernSnowflakeOperator(BaseOperator):
def execute(self, context):
# Modern execution logic
pass
else:
# Use legacy Airflow features
print(f"Running on Airflow {version_tuple}, using legacy features")
# BaseOperator imported from airflow.models
class LegacySnowflakeOperator(BaseOperator):
def execute(self, context):
# Legacy execution logic
pass
return ModernSnowflakeOperator if AIRFLOW_V_3_0_PLUS else LegacySnowflakeOperator
from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
from airflow.providers.snowflake.utils.common import enclose_param
from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator
from datetime import datetime, timedelta
def create_secure_analytics_pipeline():
"""Create analytics pipeline with comprehensive security and monitoring."""
with DAG(
'secure_analytics_pipeline',
start_date=datetime(2024, 1, 1),
schedule_interval='@daily',
catchup=False
) as dag:
def build_secure_query(region: str, category: str) -> str:
"""Build SQL query with safe parameter handling."""
# Safely encode parameters
safe_region = enclose_param(region)
safe_category = enclose_param(category)
return f"""
CREATE OR REPLACE TABLE analytics.regional_analysis AS
SELECT
region,
product_category,
COUNT(*) as transaction_count,
SUM(amount) as total_revenue,
AVG(amount) as avg_transaction
FROM raw.transactions
WHERE region = {safe_region}
AND product_category = {safe_category}
AND transaction_date = '{{{{ ds }}}}'
GROUP BY region, product_category;
-- Create summary statistics
INSERT INTO analytics.daily_summary
SELECT
'{{{{ ds }}}}' as summary_date,
{safe_region} as region,
{safe_category} as category,
SUM(transaction_count) as total_transactions,
SUM(total_revenue) as total_revenue
FROM analytics.regional_analysis;
"""
# Dynamic task creation with secure parameters
regions = ['North', 'South', 'East', 'West']
categories = ['Electronics', 'Clothing', 'Home & Garden']
tasks = []
for region in regions:
for category in categories:
task = SnowflakeSqlApiOperator(
task_id=f'analyze_{region}_{category}'.lower().replace(' ', '_').replace('&', 'and'),
snowflake_conn_id='snowflake_prod',
sql=build_secure_query(region, category),
statement_count=2,
deferrable=True,
poll_interval=20,
warehouse='ANALYTICS_WH',
# Custom JWT configuration
token_life_time=timedelta(hours=2),
token_renewal_delta=timedelta(minutes=90),
session_parameters={
'QUERY_TAG': f'secure_analytics_{region}_{category}',
'AUTOCOMMIT': True
}
)
tasks.append(task)
return tasks
# Create pipeline with all security and monitoring features
pipeline_tasks = create_secure_analytics_pipeline()
Use enclose_param() for user-provided strings. All utility functions provide comprehensive error handling: error messages include specific context and suggestions for resolution.
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-snowflake