Provider package apache-airflow-providers-snowflake for Apache Airflow
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Deferrable task execution through triggers, enabling efficient resource utilization for long-running Snowflake operations without blocking worker slots. This capability allows Airflow workers to handle other tasks while Snowflake queries execute, improving overall system throughput and resource efficiency.
Trigger for polling Snowflake SQL API query status in deferrable mode, providing asynchronous monitoring of long-running query execution with configurable polling intervals and comprehensive status reporting.
class SnowflakeSqlApiTrigger(BaseTrigger):
    """
    Trigger for polling Snowflake SQL API query status in deferrable mode.

    Monitors query execution progress and triggers task completion when
    queries finish.
    """

    def __init__(
        self,
        poll_interval: float,
        query_ids: list[str],
        snowflake_conn_id: str,
        token_life_time: timedelta,
        token_renewal_delta: timedelta,
    ):
        """
        Initialize SQL API trigger.

        Parameters:
        - poll_interval: Polling interval in seconds for checking query status
        - query_ids: List of Snowflake query IDs to monitor
        - snowflake_conn_id: Snowflake connection ID for API access
        - token_life_time: JWT token lifetime for authentication
        - token_renewal_delta: JWT token renewal interval
        """

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize trigger arguments and classpath for persistence.

        Returns:
            Tuple of (classpath, serialized_arguments) for trigger reconstruction
        """

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Wait for Snowflake queries to complete and yield trigger events.

        Continuously polls query status until all queries complete or fail.

        Yields:
            TriggerEvent objects containing query status and completion information
        """

    async def get_query_status(self, query_id: str) -> dict[str, Any]:
        """
        Get query status asynchronously from Snowflake SQL API.

        Parameters:
        - query_id: Snowflake query ID to check

        Returns:
            Dictionary containing query status, metadata, and execution details
        """

    def _set_context(self, context):
        """
        Set trigger context (no-op implementation for compatibility).

        Parameters:
        - context: Trigger execution context
        """

from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
from datetime import datetime, timedelta

with DAG(
    'deferrable_snowflake_example',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@daily',
    catchup=False,
) as dag:
    # Long-running data processing with deferrable execution
    heavy_processing = SnowflakeSqlApiOperator(
        task_id='heavy_data_processing',
        snowflake_conn_id='snowflake_prod',
        sql='''
            -- Large table creation and transformation
            CREATE OR REPLACE TABLE analytics.customer_360_view AS
            SELECT
                c.customer_id,
                c.customer_name,
                c.registration_date,
                COUNT(DISTINCT o.order_id) as total_orders,
                SUM(o.order_amount) as lifetime_value,
                AVG(o.order_amount) as avg_order_value,
                MAX(o.order_date) as last_order_date,
                MIN(o.order_date) as first_order_date,
                COUNT(DISTINCT DATE_TRUNC('month', o.order_date)) as active_months
            FROM customers c
            LEFT JOIN orders o ON c.customer_id = o.customer_id
            WHERE c.registration_date >= '2020-01-01'
            GROUP BY c.customer_id, c.customer_name, c.registration_date;

            -- Create summary statistics
            CREATE OR REPLACE TABLE analytics.customer_segments AS
            SELECT
                CASE
                    WHEN lifetime_value >= 10000 THEN 'High Value'
                    WHEN lifetime_value >= 1000 THEN 'Medium Value'
                    ELSE 'Low Value'
                END as segment,
                COUNT(*) as customer_count,
                AVG(lifetime_value) as avg_segment_value,
                AVG(total_orders) as avg_orders_per_customer
            FROM analytics.customer_360_view
            GROUP BY 1;
        ''',
        statement_count=2,      # Must match the number of SQL statements above
        deferrable=True,        # Enable deferrable execution
        poll_interval=30,       # Check status every 30 seconds
        warehouse='X_LARGE_WH',
        database='ANALYTICS',
    )

from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
# Multiple independent long-running operations
parallel_processing = [
    SnowflakeSqlApiOperator(
        task_id=f'process_region_{region}',
        snowflake_conn_id='snowflake_prod',
        sql=f'''
            CREATE OR REPLACE TABLE analytics.regional_summary_{region.lower()} AS
            SELECT
                DATE_TRUNC('month', order_date) as month,
                COUNT(*) as total_orders,
                SUM(order_amount) as total_revenue,
                COUNT(DISTINCT customer_id) as unique_customers
            FROM orders
            WHERE region = '{region}'
              AND order_date >= '2023-01-01'
            GROUP BY 1
            ORDER BY 1;
        ''',
        statement_count=1,
        deferrable=True,
        poll_interval=15,  # Faster polling for smaller queries
        warehouse='LARGE_WH',
        session_parameters={
            # Tag each query so it is attributable in QUERY_HISTORY
            'QUERY_TAG': f'regional_processing_{region}_{datetime.now().isoformat()}'
        },
    )
    for region in ['NORTH', 'SOUTH', 'EAST', 'WEST']
]

# All regional processing tasks run in parallel without blocking workers

from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator

with DAG(
    'deferrable_etl_pipeline',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@daily',
    catchup=False,
    max_active_runs=1,
) as dag:
    # Stage 1: Data ingestion (synchronous - typically fast)
    ingest_data = CopyFromExternalStageToSnowflakeOperator(
        task_id='ingest_raw_data',
        table='raw.daily_transactions',
        stage='@s3_data_stage',
        prefix='transactions/{{ ds }}/',
        file_format='csv_transactions',
        warehouse='LOADING_WH',
    )

    # Stage 2: Heavy data processing (deferrable)
    process_transactions = SnowflakeSqlApiOperator(
        task_id='process_transactions',
        snowflake_conn_id='snowflake_prod',
        sql='''
            -- Clean and standardize transaction data
            CREATE OR REPLACE TABLE staging.clean_transactions AS
            SELECT
                transaction_id,
                customer_id,
                UPPER(TRIM(product_category)) as product_category,
                ROUND(transaction_amount, 2) as transaction_amount,
                transaction_date,
                CASE
                    WHEN payment_method IN ('CREDIT', 'DEBIT', 'CASH') THEN payment_method
                    ELSE 'OTHER'
                END as payment_method_clean
            FROM raw.daily_transactions
            WHERE transaction_amount > 0
              AND customer_id IS NOT NULL
              AND transaction_date = '{{ ds }}';

            -- Create enriched transaction view with customer data
            CREATE OR REPLACE TABLE staging.enriched_transactions AS
            SELECT
                t.*,
                c.customer_segment,
                c.customer_tier,
                c.registration_date,
                DATEDIFF('day', c.registration_date, t.transaction_date) as days_since_registration
            FROM staging.clean_transactions t
            JOIN dim.customers c ON t.customer_id = c.customer_id;

            -- Aggregate daily metrics by segment
            CREATE OR REPLACE TABLE analytics.daily_segment_metrics AS
            SELECT
                '{{ ds }}' as metric_date,
                customer_segment,
                product_category,
                payment_method_clean,
                COUNT(*) as transaction_count,
                SUM(transaction_amount) as total_revenue,
                AVG(transaction_amount) as avg_transaction_value,
                COUNT(DISTINCT customer_id) as unique_customers
            FROM staging.enriched_transactions
            GROUP BY customer_segment, product_category, payment_method_clean;
        ''',
        statement_count=3,
        deferrable=True,
        poll_interval=20,
        warehouse='HEAVY_PROCESSING_WH',
        token_life_time=timedelta(hours=2),  # Extended token lifetime for long operations
    )

    # Stage 3: ML feature generation (deferrable)
    generate_ml_features = SnowflakeSqlApiOperator(
        task_id='generate_ml_features',
        snowflake_conn_id='snowflake_prod',
        sql='''
            -- Generate rolling window features
            CREATE OR REPLACE TABLE ml.customer_features_{{ ds | regex_replace('-', '_') }} AS
            SELECT
                customer_id,
                '{{ ds }}' as feature_date,
                -- 7-day rolling features
                COUNT(*) OVER (
                    PARTITION BY customer_id
                    ORDER BY transaction_date
                    RANGE BETWEEN INTERVAL '7 days' PRECEDING AND CURRENT ROW
                ) as transactions_7d,
                SUM(transaction_amount) OVER (
                    PARTITION BY customer_id
                    ORDER BY transaction_date
                    RANGE BETWEEN INTERVAL '7 days' PRECEDING AND CURRENT ROW
                ) as revenue_7d,
                -- 30-day rolling features
                COUNT(*) OVER (
                    PARTITION BY customer_id
                    ORDER BY transaction_date
                    RANGE BETWEEN INTERVAL '30 days' PRECEDING AND CURRENT ROW
                ) as transactions_30d,
                SUM(transaction_amount) OVER (
                    PARTITION BY customer_id
                    ORDER BY transaction_date
                    RANGE BETWEEN INTERVAL '30 days' PRECEDING AND CURRENT ROW
                ) as revenue_30d,
                -- Recency features
                DATEDIFF('day',
                    LAG(transaction_date) OVER (PARTITION BY customer_id ORDER BY transaction_date),
                    transaction_date
                ) as days_since_last_transaction
            FROM staging.enriched_transactions
            ORDER BY customer_id, transaction_date;

            -- Update master feature table
            MERGE INTO ml.customer_features_master m
            USING ml.customer_features_{{ ds | regex_replace('-', '_') }} f
            ON m.customer_id = f.customer_id AND m.feature_date = f.feature_date
            WHEN MATCHED THEN UPDATE SET
                transactions_7d = f.transactions_7d,
                revenue_7d = f.revenue_7d,
                transactions_30d = f.transactions_30d,
                revenue_30d = f.revenue_30d,
                days_since_last_transaction = f.days_since_last_transaction
            WHEN NOT MATCHED THEN INSERT (
                customer_id, feature_date, transactions_7d, revenue_7d,
                transactions_30d, revenue_30d, days_since_last_transaction
            ) VALUES (
                f.customer_id, f.feature_date, f.transactions_7d, f.revenue_7d,
                f.transactions_30d, f.revenue_30d, f.days_since_last_transaction
            );
        ''',
        statement_count=2,
        deferrable=True,
        poll_interval=25,
        warehouse='ML_WH',
    )

    # Stage 4: Data quality validation (synchronous - fast)
    validate_results = SnowflakeSqlApiOperator(
        task_id='validate_processing_results',
        snowflake_conn_id='snowflake_prod',
        sql='''
            -- Validate record counts match expectations
            SELECT
                CASE
                    WHEN staging_count > 0 AND
                         staging_count = analytics_count AND
                         ml_count > 0
                    THEN 'PASSED'
                    ELSE 'FAILED'
                END as validation_result
            FROM (
                SELECT
                    (SELECT COUNT(*) FROM staging.clean_transactions) as staging_count,
                    (SELECT SUM(transaction_count) FROM analytics.daily_segment_metrics WHERE metric_date = '{{ ds }}') as analytics_count,
                    (SELECT COUNT(DISTINCT customer_id) FROM ml.customer_features_master WHERE feature_date = '{{ ds }}') as ml_count
            );
        ''',
        statement_count=1,
        warehouse='VALIDATION_WH',
    )

    # Define pipeline dependencies
    ingest_data >> process_transactions >> generate_ml_features >> validate_results

from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
class CustomSnowflakeOperator(SnowflakeSqlApiOperator):
"""Custom operator with enhanced trigger event handling."""
def execute_complete(self, context, event=None):
"""Custom completion handler with detailed logging."""
# Extract query results from trigger event
if event and 'query_results' in event:
query_results = event['query_results']
# Log execution statistics
for query_id, result in query_results.items():
if result.get('status') == 'SUCCESS':
execution_time = result.get('execution_time_ms', 0) / 1000
rows_affected = result.get('rows_affected', 0)
self.log.info(
f"Query {query_id} completed successfully: "
f"{rows_affected} rows affected in {execution_time:.2f} seconds"
)
# Store metrics for monitoring
context['task_instance'].xcom_push(
key=f'query_{query_id}_metrics',
value={
'execution_time_seconds': execution_time,
'rows_affected': rows_affected,
'status': 'SUCCESS'
}
)
else:
self.log.error(f"Query {query_id} failed: {result.get('error_message', 'Unknown error')}")
raise Exception(f"Query execution failed: {result.get('error_message')}")
# Call parent completion handler
super().execute_complete(context, event)
# Usage of custom operator
custom_deferrable_task = CustomSnowflakeOperator(
task_id='custom_processing_with_metrics',
sql='SELECT COUNT(*) FROM large_table WHERE date >= CURRENT_DATE - 30',
statement_count=1,
deferrable=True,
poll_interval=10
)The SnowflakeSqlApiTrigger yields TriggerEvent objects with the following structure:
TriggerEvent({
"status": "success" | "error",
"query_results": {
"query_id_1": {
"status": "SUCCESS" | "FAILED" | "RUNNING",
"execution_time_ms": 1234,
"rows_affected": 567,
"error_message": "...", # Only present on failure
"query_text": "...",
"warehouse": "..."
},
# ... additional query results
},
"message": "All queries completed successfully" | "Error details"
})

Set token_life_time to cover the expected query duration, and set
token_renewal_delta to ensure tokens don't expire during execution.

-- Monitor long-running queries
-- Pull identifying, timing, and status columns for each query.
SELECT
query_id,
query_text,
user_name,
warehouse_name,
start_time,
end_time,
total_elapsed_time,
execution_status
FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY())
-- Restrict to the last 24 hours and to queries still executing or queued.
WHERE start_time >= DATEADD(hour, -24, CURRENT_TIMESTAMP())
AND execution_status IN ('RUNNING', 'QUEUED')
ORDER BY start_time DESC;

All trigger events are logged with detailed information about query execution.
The deferrable execution system provides comprehensive error handling:
all errors include detailed logging, Snowflake query IDs for investigation, and clear guidance for resolution.
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-snowflake