Provider package for Apache Spark integration with Apache Airflow, offering operators, hooks, sensors, and decorators for distributed data processing workflows.
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Task decorators for seamless PySpark integration within Airflow workflows. These decorators automatically inject Spark session objects into Python functions, enabling native PySpark development patterns with automatic session management and cleanup.
Transform Python functions into Airflow tasks that automatically receive a configured Spark session, enabling seamless integration of PySpark code within Airflow workflows.
def pyspark_task(
python_callable: Callable | None = None,
multiple_outputs: bool | None = None,
**kwargs,
) -> TaskDecorator:
"""
Decorator for creating PySpark tasks with automatic Spark session injection.
The decorated function automatically receives a SparkSession as its first parameter
(typically named 'spark'). The session is created, configured, and cleaned up
automatically by the decorator.
Parameters:
- python_callable (Callable, optional): Python function to decorate
- multiple_outputs (bool, optional): Whether function returns multiple outputs
- **kwargs: Additional arguments passed to underlying DecoratedOperator
Returns:
TaskDecorator: Configured task decorator for PySpark functions
Usage:
@pyspark_task
def my_function(spark):
# spark is automatically injected SparkSession
df = spark.read.parquet('/path/to/data')
return df.count()
"""Internal operator class that handles PySpark task execution with automatic session management.
class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
    """
    Internal decorated operator for PySpark tasks.

    Inherits from both DecoratedOperator and PythonOperator.
    This class is created automatically by the @pyspark_task decorator
    and should not be instantiated directly.

    Template Fields: op_args, op_kwargs

    Key Features:
    - Automatic SparkSession creation and injection
    - Session configuration from Airflow connections
    - Proper session cleanup after task execution
    - Integration with Airflow's task execution framework
    """

    def __init__(
        self,
        python_callable: Callable,
        op_args: Sequence | None = None,
        op_kwargs: dict | None = None,
        conn_id: str | None = None,
        config_kwargs: dict | None = None,
        **kwargs
    ): ...

    def execute(self, context) -> Any:
        """
        Execute PySpark function with injected Spark session.

        Parameters:
        - context: Airflow task execution context

        Returns:
        Any: Result returned by the decorated function
        """
from airflow.providers.apache.spark.decorators.pyspark import pyspark_task
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, avg
@pyspark_task(task_id='analyze_user_data')
def analyze_users(spark):
    """
    Analyze user data using PySpark.

    Args:
        spark (SparkSession): Automatically injected Spark session

    Returns:
        dict: Analysis results
    """
    # Read data from various sources
    users_df = spark.read.parquet('/data/users.parquet')
    transactions_df = spark.read.json('/data/transactions.json')

    # Perform analysis
    user_stats = users_df.groupBy('region') \
        .agg(
            count('user_id').alias('user_count'),
            avg('age').alias('avg_age')
        )

    # Join with transaction data
    enriched_stats = user_stats.join(
        transactions_df.groupBy('region').agg(
            count('transaction_id').alias('transaction_count')
        ),
        on='region'
    )

    # Collect results (driver-side list of Rows)
    results = enriched_stats.collect()

    # Write results
    enriched_stats.write \
        .mode('overwrite') \
        .parquet('/data/user_analysis_results.parquet')

    # Builtin sum over the collected rows — this one is NOT pyspark's sum
    return {
        'regions_analyzed': len(results),
        'total_users': sum(row['user_count'] for row in results)
    }

# Use in DAG
user_analysis_task = analyze_users()
from airflow.providers.apache.spark.decorators.pyspark import pyspark_task
from airflow import DAG
from datetime import datetime, timedelta

# Daily ETL DAG that the decorated PySpark tasks below attach to.
dag = DAG(
    'pyspark_etl',
    default_args={
        'owner': 'data-team',
        'retries': 2,  # retry each failed task twice
        'retry_delay': timedelta(minutes=5),
    },
    description='PySpark ETL workflow',
    schedule_interval='@daily',  # NOTE(review): `schedule_interval` is deprecated in newer Airflow (use `schedule`) — confirm target version
    start_date=datetime(2023, 1, 1),
    catchup=False,  # do not backfill runs before "now"
)
@pyspark_task(
task_id='process_daily_data',
conn_id='spark_cluster', # Use specific Spark connection
dag=dag
)
def process_daily_data(spark, **context):
"""
Process daily data with date templating.
Args:
spark (SparkSession): Spark session
**context: Airflow context with templated variables
"""
execution_date = context['ds'] # YYYY-MM-DD format
# Read data for specific date
daily_data = spark.read.parquet(f'/data/raw/{execution_date}/*.parquet')
# Data processing
processed_data = daily_data \
.filter(col('status') == 'active') \
.withColumn('processing_date', lit(execution_date)) \
.groupBy('category', 'processing_date') \
.agg(
count('id').alias('record_count'),
sum('amount').alias('total_amount'),
avg('amount').alias('avg_amount')
)
# Write processed data
output_path = f'/data/processed/{execution_date}'
processed_data.write \
.mode('overwrite') \
.partitionBy('processing_date') \
.parquet(output_path)
return {
'execution_date': execution_date,
'output_path': output_path,
'record_count': processed_data.count()
}
process_task = process_daily_data()@pyspark_task(
task_id='generate_multiple_reports',
multiple_outputs=True # Enable multiple outputs
)
def generate_reports(spark):
"""
Generate multiple reports from data analysis.
Returns:
dict: Multiple named outputs for downstream tasks
"""
# Load source data
sales_df = spark.read.table('sales')
customers_df = spark.read.table('customers')
# Generate sales report
sales_summary = sales_df.groupBy('region', 'product_category') \
.agg(
sum('amount').alias('total_sales'),
count('order_id').alias('order_count')
)
sales_report_path = '/reports/sales_summary.parquet'
sales_summary.write.mode('overwrite').parquet(sales_report_path)
# Generate customer report
customer_summary = customers_df.groupBy('region', 'segment') \
.agg(
count('customer_id').alias('customer_count'),
avg('lifetime_value').alias('avg_ltv')
)
customer_report_path = '/reports/customer_summary.parquet'
customer_summary.write.mode('overwrite').parquet(customer_report_path)
# Return multiple outputs
return {
'sales_report': {
'path': sales_report_path,
'record_count': sales_summary.count()
},
'customer_report': {
'path': customer_report_path,
'record_count': customer_summary.count()
},
'execution_summary': {
'total_sales': sales_df.agg(sum('amount')).collect()[0][0],
'total_customers': customers_df.count()
}
}
# Use multiple outputs in downstream tasks
reports_task = generate_reports()
@task
def process_sales_report(sales_data):
print(f"Processing sales report at: {sales_data['path']}")
print(f"Records processed: {sales_data['record_count']}")
@task
def process_customer_report(customer_data):
print(f"Processing customer report at: {customer_data['path']}")
print(f"Records processed: {customer_data['record_count']}")
# Access individual outputs
process_sales_report(reports_task['sales_report'])
process_customer_report(reports_task['customer_report'])@pyspark_task(
task_id='ml_feature_engineering',
packages=['org.apache.spark:spark-mllib_2.12:3.5.0'], # MLlib dependency
)
def feature_engineering(spark):
"""
Perform feature engineering using Spark MLlib.
"""
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml import Pipeline
# Load raw data
raw_data = spark.read.csv('/data/raw_features.csv', header=True, inferSchema=True)
# Define feature engineering pipeline
string_indexer = StringIndexer(
inputCol='category',
outputCol='category_indexed'
)
vector_assembler = VectorAssembler(
inputCols=['feature1', 'feature2', 'feature3', 'category_indexed'],
outputCol='features_raw'
)
scaler = StandardScaler(
inputCol='features_raw',
outputCol='features_scaled'
)
# Create and fit pipeline
pipeline = Pipeline(stages=[string_indexer, vector_assembler, scaler])
pipeline_model = pipeline.fit(raw_data)
# Transform data
processed_data = pipeline_model.transform(raw_data)
# Save processed features
processed_data.select('id', 'features_scaled', 'target') \
.write.mode('overwrite') \
.parquet('/data/processed_features.parquet')
return {
'input_records': raw_data.count(),
'output_records': processed_data.count(),
'feature_count': len(vector_assembler.getInputCols())
}
ml_task = feature_engineering()The decorator automatically configures the Spark session based on:
- the conn_id parameter
# Automatic session lifecycle management
@pyspark_task
def my_task(spark):
# Session is created and configured automatically
# No need to call spark.stop() - handled by decorator
df = spark.createDataFrame([...])
return df.count()
# Session is automatically stopped after function completes@pyspark_task
def robust_pyspark_task(spark):
try:
# PySpark operations
result = spark.sql("SELECT COUNT(*) FROM large_table")
return result.collect()[0][0]
except Exception as e:
# Log error details
print(f"PySpark task failed: {str(e)}")
# Session cleanup handled automatically
raise # Re-raise for Airflow error handling@pyspark_task(
task_id='memory_intensive_task',
# Configure resources through connection or task parameters
)
def memory_intensive_processing(spark):
# Configure Spark session for memory-intensive operations
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
# Process large datasets
large_df = spark.read.parquet('/data/very_large_dataset.parquet')
# Use appropriate partitioning
result = large_df.repartition(200) \
.groupBy('partition_key') \
.agg(various_aggregations) \
.cache() # Cache intermediate results
return result.count()@pyspark_task
def produce_data(spark):
# Process data and return summary for downstream tasks
df = spark.read.table('source_data')
summary = {
'record_count': df.count(),
'output_path': '/data/processed/{{ ds }}'
}
return summary # Automatically stored in XCom
@task
def consume_data(summary_data):
# Access PySpark task results
print(f"Processing {summary_data['record_count']} records")
print(f"Data location: {summary_data['output_path']}")
# Chain tasks
summary = produce_data()
consume_data(summary)@pyspark_task(
task_id='templated_spark_task',
# Template fields work with decorator
)
def templated_processing(spark, **context):
# Access Airflow templating
execution_date = context['ds']
dag_run_id = context['dag_run'].run_id
# Use in PySpark operations
df = spark.read.parquet(f'/data/{execution_date}/*.parquet')
# Add metadata columns
result = df.withColumn('dag_run_id', lit(dag_run_id)) \
.withColumn('processing_date', lit(execution_date))
return result.count()SPARK_CONTEXT_KEYS = ["spark", "sc"]These keys define the parameter names that will receive the Spark session or context objects in decorated functions. The decorator automatically injects the Spark session using the first available parameter name from this list.
@pyspark_task(
task_id='debug_task',
conf={
'spark.sql.execution.arrow.pyspark.enabled': 'true',
'spark.sql.adaptive.enabled': 'true',
'spark.eventLog.enabled': 'true',
'spark.eventLog.dir': '/tmp/spark-events'
}
)
def debug_processing(spark):
# Enable debug logging
spark.sparkContext.setLogLevel("DEBUG")
# Your PySpark code with detailed logging
df = spark.read.parquet('/data/input.parquet')
print(f"Input partitions: {df.rdd.getNumPartitions()}")
print(f"Input schema: {df.schema}")
result = df.groupBy('key').count()
print(f"Result count: {result.count()}")
return result.collect()Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-apache-spark