Provider package apache-airflow-providers-snowflake for Apache Airflow
Native Snowpark Python integration enabling DataFrame-based data processing workflows directly within Airflow tasks. This capability provides automatic Snowpark session management, seamless integration with Airflow's task execution model, and native Python-based data transformations that run directly in Snowflake's compute environment.
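A minimal sketch of the pattern, assuming a configured "snowflake_default" Airflow connection and a hypothetical raw.sales_transactions table; fuller operator and decorator examples follow below.

from airflow.providers.snowflake.decorators.snowpark import snowpark_task

@snowpark_task()
def row_count(session, **context):
    # 'session' is the Snowpark session injected automatically from the Airflow connection
    return session.table("raw.sales_transactions").count()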
Operator for executing Python functions with Snowpark integration, automatically injecting a configured Snowpark session into the callable function.
class SnowparkOperator(PythonOperator):
    """
    Execute Python function with Snowpark Python code.
    Automatically injects a Snowpark session configured with connection parameters.
    """

    def __init__(
        self,
        *,
        snowflake_conn_id: str = "snowflake_default",
        python_callable: Callable,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        templates_dict: dict[str, Any] | None = None,
        templates_exts: Sequence[str] | None = None,
        show_return_value_in_logs: bool = True,
        warehouse: str | None = None,
        database: str | None = None,
        schema: str | None = None,
        role: str | None = None,
        authenticator: str | None = None,
        session_parameters: dict | None = None,
        **kwargs,
    ):
        """
        Initialize Snowpark operator.

        Parameters:
        - snowflake_conn_id: Snowflake connection ID
        - python_callable: Python function to execute with Snowpark session
        - op_args: Positional arguments for python_callable
        - op_kwargs: Keyword arguments for python_callable
        - templates_dict: Dictionary of templates for Jinja templating
        - templates_exts: File extensions to apply Jinja templating
        - show_return_value_in_logs: Show function return value in logs
        - warehouse: Snowflake warehouse name
        - database: Snowflake database name
        - schema: Snowflake schema name
        - role: Snowflake role name
        - authenticator: Authentication method
        - session_parameters: Session-level parameters
        """

    def execute_callable(self):
        """
        Execute the callable with Snowpark session injection.
        Automatically provides 'session' parameter to callable if defined in signature.

        Returns:
            Result of python_callable execution
        """

Decorator function for converting regular Python functions into Snowpark-enabled Airflow tasks with automatic session management.
def snowpark_task(
    python_callable: Callable | None = None,
    multiple_outputs: bool | None = None,
    **kwargs,
) -> TaskDecorator:
    """
    Decorator to wrap a function containing Snowpark code into an Airflow operator.

    Parameters:
    - python_callable: Function to be decorated (auto-provided when used as decorator)
    - multiple_outputs: Enable multiple outputs for XCom
    - **kwargs: Additional arguments passed to SnowparkOperator

    Returns:
        TaskDecorator for creating Snowpark tasks
    """

class _SnowparkDecoratedOperator(DecoratedOperator, SnowparkOperator):
    """
    Internal decorated operator for Snowpark tasks.
    Combines DecoratedOperator functionality with Snowpark session management.
    """

    custom_operator_name = "@task.snowpark"

from airflow import DAG
from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
from datetime import datetime

def process_sales_data(session, **context):
    """
    Process sales data using Snowpark DataFrame API.

    Args:
        session: Snowpark session (automatically injected)
        **context: Airflow context variables
    """
    # Read data using Snowpark
    raw_sales = session.table("raw.sales_transactions")

    # Transform data using DataFrame API
    daily_sales = (
        raw_sales
        .filter(raw_sales.col("transaction_date") >= context['ds'])
        .filter(raw_sales.col("transaction_date") < context['next_ds'])
        .group_by("region", "product_category")
        .agg({
            "amount": "sum",
            "transaction_id": "count"
        })
        .with_column_renamed("SUM(AMOUNT)", "total_sales")
        .with_column_renamed("COUNT(TRANSACTION_ID)", "transaction_count")
    )

    # Write results back to Snowflake
    daily_sales.write.save_as_table(
        "analytics.daily_sales_summary",
        mode="append"
    )

    # Return metrics for downstream tasks
    total_records = daily_sales.count()
    return {"processed_records": total_records}

with DAG(
    'snowpark_processing_example',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@daily',
    catchup=False
) as dag:
    process_data = SnowparkOperator(
        task_id='process_daily_sales',
        snowflake_conn_id='snowflake_prod',
        python_callable=process_sales_data,
        warehouse='ANALYTICS_WH',
        database='ANALYTICS_DB',
        schema='PUBLIC'
    )

from airflow import DAG
from airflow.providers.snowflake.decorators.snowpark import snowpark_task
from datetime import datetime

@snowpark_task(
    snowflake_conn_id='snowflake_prod',
    warehouse='ML_WH',
    database='FEATURE_STORE'
)
def create_ml_features(session, **context):
    """
    Create machine learning features using Snowpark.

    Args:
        session: Snowpark session (automatically injected)
    """
    # Load base tables
    customers = session.table("raw.customers")
    orders = session.table("raw.orders")

    # Create feature engineering pipeline
    customer_features = (
        customers
        .join(orders, customers.col("customer_id") == orders.col("customer_id"), "left")
        .group_by("customer_id", "customer_segment", "registration_date")
        .agg({
            "order_amount": "sum",
            "order_id": "count",
            "order_date": "max"
        })
        .with_column_renamed("SUM(ORDER_AMOUNT)", "lifetime_value")
        .with_column_renamed("COUNT(ORDER_ID)", "total_orders")
        .with_column_renamed("MAX(ORDER_DATE)", "last_order_date")
    )
    # Add derived features
    from snowflake.snowpark.functions import col, when, datediff, current_date
    enriched_features = (
        customer_features
        .with_column("customer_tier",
                     when(col("total_orders") > 10, "high_value")
                     .when(col("total_orders") > 5, "medium_value")
                     .otherwise("low_value"))
        .with_column("days_since_last_order",
                     datediff("day", col("last_order_date"), current_date()))
        .with_column("avg_order_value",
                     col("lifetime_value") / col("total_orders"))
    )
    # Save feature table
    enriched_features.write.save_as_table(
        "features.customer_features_v1",
        mode="overwrite"
    )
    return enriched_features.count()

@snowpark_task(
    snowflake_conn_id='snowflake_prod',
    warehouse='ML_WH'
)
def train_model_features(session, **context):
    """
    Prepare training dataset using Snowpark ML functions.
    """
    from snowflake.snowpark.functions import col, when

    # Load feature table
    features = session.table("features.customer_features_v1")

    # Prepare training data with labels
    training_data = (
        features
        .filter(col("days_since_last_order") <= 365)  # Active customers only
        .with_column(
            "will_churn",
            when(col("days_since_last_order") > 90, 1).otherwise(0)
        )
        .select([
            "customer_id",
            "lifetime_value",
            "total_orders",
            "avg_order_value",
            "days_since_last_order",
            "will_churn"
        ])
    )

    # Save training dataset
    training_data.write.save_as_table(
        "ml.churn_training_data",
        mode="overwrite"
    )
    return training_data.count()
with DAG(
    'ml_feature_pipeline',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@weekly',
    catchup=False
) as dag:
    # Tasks are automatically created from decorated functions
    features_task = create_ml_features()
    training_task = train_model_features()
    features_task >> training_task

from airflow.providers.snowflake.decorators.snowpark import snowpark_task
from snowflake.snowpark.functions import col, when, sum as spark_sum, count, avg, max as spark_max
from snowflake.snowpark.types import StructType, StructField, StringType, IntegerType, DoubleType

@snowpark_task(
    snowflake_conn_id='snowflake_prod',
    warehouse='ETL_WH',
    multiple_outputs=True
)
def comprehensive_etl_process(session, **context):
    """
    Comprehensive ETL process using Snowpark DataFrame API.
    """
    execution_date = context['ds']

    # 1. Data Quality Checks
    raw_data = session.table("raw.transaction_stream")
    quality_metrics = {
        'total_records': raw_data.count(),
        'null_customer_ids': raw_data.filter(col("customer_id").is_null()).count(),
        'invalid_amounts': raw_data.filter(col("amount") <= 0).count()
    }

    # 2. Data Cleaning and Transformation
    clean_data = (
        raw_data
        .filter(col("customer_id").is_not_null())
        .filter(col("amount") > 0)
        .filter(col("transaction_date") >= execution_date)
        .with_column("amount_category",
                     when(col("amount") >= 1000, "high")
                     .when(col("amount") >= 100, "medium")
                     .otherwise("low"))
    )

    # 3. Aggregation and Business Logic
    customer_summary = (
        clean_data
        .group_by("customer_id", "amount_category")
        .agg({
            "amount": "sum",
            "transaction_id": "count"
        })
        .with_column_renamed("SUM(AMOUNT)", "total_spent")
        .with_column_renamed("COUNT(TRANSACTION_ID)", "transaction_count")
    )

    # 4. Advanced Analytics
    pivot_summary = customer_summary.pivot(
        "amount_category",
        ["high", "medium", "low"]
    ).agg({
        "total_spent": "sum",
        "transaction_count": "sum"
    })

    # 5. Write Results to Multiple Tables
    # Clean transactional data
    clean_data.write.save_as_table(
        f"staging.clean_transactions_{execution_date.replace('-', '_')}",
        mode="overwrite"
    )
    # Customer summaries
    customer_summary.write.save_as_table(
        "analytics.customer_transaction_summary",
        mode="append"
    )
    # Pivot analysis
    pivot_summary.write.save_as_table(
        "analytics.spending_category_analysis",
        mode="append"
    )

    # Return comprehensive metrics
    return {
        'quality_metrics': quality_metrics,
        'processed_customers': customer_summary.select("customer_id").distinct().count(),
        'clean_records': clean_data.count(),
        'summary_records': customer_summary.count()
    }

@snowpark_task(
    snowflake_conn_id='snowflake_prod',
    warehouse='ANALYTICS_WH'
)
def generate_business_reports(session, processed_metrics, **context):
    """
    Generate business reports using processed data.

    Args:
        processed_metrics: Output from previous Snowpark task
    """
    execution_date = context['ds']

    # Create executive summary report
    summary_data = session.sql(f"""
        SELECT
            '{execution_date}' as report_date,
            COUNT(DISTINCT customer_id) as active_customers,
            SUM(total_spent) as total_revenue,
            AVG(total_spent) as avg_customer_spend,
            SUM(transaction_count) as total_transactions
        FROM analytics.customer_transaction_summary
        WHERE DATE(created_at) = '{execution_date}'
    """)

    # Save executive dashboard data
    summary_data.write.save_as_table(
        "reports.daily_executive_summary",
        mode="append"
    )
    return {
        'report_generated': True,
        'input_metrics': processed_metrics
    }

@snowpark_task(
    snowflake_conn_id='snowflake_prod',
    warehouse='HEAVY_COMPUTE_WH',
    session_parameters={
        'QUERY_TAG': 'airflow_snowpark_processing',
        'MULTI_STATEMENT_COUNT': 5,
        'AUTOCOMMIT': True
    }
)
def advanced_snowpark_processing(session, **context):
    """
    Advanced Snowpark processing with custom session configuration.
    """
    from snowflake.snowpark.functions import col

    # Enable query profiling
    session.sql("ALTER SESSION SET USE_CACHED_RESULT = FALSE").collect()

    # Use Snowpark ML functions (if available)
    try:
        from snowflake.ml.functions import detect_anomalies

        # Load time series data
        ts_data = session.table("analytics.daily_metrics")

        # Detect anomalies using ML functions
        anomaly_results = ts_data.select(
            "*",
            detect_anomalies(col("metric_value")).over(
                partition_by=col("metric_type"),
                order_by=col("date")
            ).alias("is_anomaly")
        )

        # Save anomaly detection results
        anomaly_results.write.save_as_table(
            "ml.anomaly_detection_results",
            mode="append"
        )
    except ImportError:
        # Fallback to statistical anomaly detection
        session.sql("""
            CREATE OR REPLACE TABLE ml.anomaly_detection_results AS
            SELECT *,
                CASE WHEN ABS(metric_value - AVG(metric_value) OVER (
                    PARTITION BY metric_type
                    ORDER BY date
                    ROWS BETWEEN 7 PRECEDING AND 1 PRECEDING
                )) > 2 * STDDEV(metric_value) OVER (
                    PARTITION BY metric_type
                    ORDER BY date
                    ROWS BETWEEN 7 PRECEDING AND 1 PRECEDING
                ) THEN TRUE ELSE FALSE END as is_anomaly
            FROM analytics.daily_metrics
        """).collect()
return {"anomaly_detection_completed": True}The Snowpark integration automatically handles:
Snowpark provides a rich DataFrame API for data processing:
session.table(): Load existing tables
session.sql(): Execute SQL and return DataFrame
session.read.options().csv(): Read from files
filter(): Filter rows based on conditions
select(): Select specific columns
group_by().agg(): Grouping and aggregation
join(): Join operations between DataFrames
pivot(): Pivot table operations
with_column(): Add computed columns
collect(): Materialize DataFrame results
count(): Count rows
write.save_as_table(): Save to Snowflake table
show(): Display sample data
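A short illustrative sketch combining several of these methods, assuming an existing Snowpark session and a hypothetical raw.orders table:

from snowflake.snowpark.functions import col

def summarize_orders(session):
    # 'raw.orders' and 'analytics.orders_by_region' are hypothetical tables
    orders = session.table("raw.orders")
    summary = (
        orders
        .filter(col("amount") > 0)
        .select("region", "amount")
        .group_by("region")
        .agg({"amount": "sum"})
        .with_column_renamed("SUM(AMOUNT)", "total_amount")
    )
    summary.show()            # preview a few rows in the task log
    rows = summary.collect()  # materialize results locally
    summary.write.save_as_table("analytics.orders_by_region", mode="overwrite")
    return len(rows)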
The Snowpark integration provides comprehensive error handling: all errors include detailed stack traces and Snowflake-specific error information for troubleshooting.
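For example, Snowflake query failures inside a task surface as Snowpark exceptions that can be caught, logged, and re-raised so the Airflow task fails visibly. A minimal sketch, assuming the standard snowflake.snowpark.exceptions module:

import logging
from snowflake.snowpark.exceptions import SnowparkSQLException

def validate_table(session, table_name):
    try:
        return session.table(table_name).count()
    except SnowparkSQLException as exc:
        # The exception carries Snowflake-specific error details for troubleshooting
        logging.error("Snowpark query failed for %s: %s", table_name, exc)
        raise  # re-raise so the Airflow task is marked as failed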
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-snowflake