tessl/pypi-apache-airflow-providers-apache-spark

Provider package for Apache Spark integration with Apache Airflow, offering operators, hooks, sensors, and decorators for distributed data processing workflows.

docs/pyspark-decorators.md

PySpark Decorators

Task decorators for seamless PySpark integration within Airflow workflows. These decorators automatically inject Spark session objects into Python functions, enabling native PySpark development patterns with automatic session management and cleanup.

Capabilities

PySpark Task Decorator

Transform Python functions into Airflow tasks that automatically receive a configured Spark session, allowing PySpark code to run natively inside Airflow workflows.

def pyspark_task(
    python_callable: Callable | None = None,
    multiple_outputs: bool | None = None,
    **kwargs,
) -> TaskDecorator:
    """
    Decorator for creating PySpark tasks with automatic Spark session injection.
    
    The decorated function automatically receives a SparkSession as its first parameter
    (typically named 'spark'). The session is created, configured, and cleaned up
    automatically by the decorator.
    
    Parameters:
    - python_callable (Callable, optional): Python function to decorate
    - multiple_outputs (bool, optional): Whether function returns multiple outputs
    - **kwargs: Additional arguments passed to underlying DecoratedOperator
    
    Returns:
    TaskDecorator: Configured task decorator for PySpark functions
    
    Usage:
    @pyspark_task
    def my_function(spark):
        # spark is automatically injected SparkSession
        df = spark.read.parquet('/path/to/data')
        return df.count()
    """

PySpark Decorated Operator

Internal operator class that handles PySpark task execution with automatic session management.

class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
    """
    Internal decorated operator for PySpark tasks.
    Inherits from both DecoratedOperator and PythonOperator.
    
    This class is created automatically by the @pyspark_task decorator
    and should not be instantiated directly.
    
    Template Fields: op_args, op_kwargs
    
    Key Features:
    - Automatic SparkSession creation and injection
    - Session configuration from Airflow connections
    - Proper session cleanup after task execution
    - Integration with Airflow's task execution framework
    """
    
    def __init__(
        self,
        python_callable: Callable,
        op_args: Sequence | None = None,
        op_kwargs: dict | None = None,
        conn_id: str | None = None,
        config_kwargs: dict | None = None,
        **kwargs
    ): ...
    
    def execute(self, context) -> Any:
        """
        Execute PySpark function with injected Spark session.
        
        Parameters:
        - context: Airflow task execution context
        
        Returns:
        Any: Result returned by the decorated function
        """

Usage Examples

Basic PySpark Task

from airflow.providers.apache.spark.decorators.pyspark import pyspark_task
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, avg

@pyspark_task(task_id='analyze_user_data')
def analyze_users(spark):
    """
    Analyze user data using PySpark.
    
    Args:
        spark (SparkSession): Automatically injected Spark session
        
    Returns:
        dict: Analysis results
    """
    # Read data from various sources
    users_df = spark.read.parquet('/data/users.parquet')
    transactions_df = spark.read.json('/data/transactions.json')
    
    # Perform analysis
    user_stats = users_df.groupBy('region') \
        .agg(
            count('user_id').alias('user_count'),
            avg('age').alias('avg_age')
        )
    
    # Join with transaction data
    enriched_stats = user_stats.join(
        transactions_df.groupBy('region').agg(
            count('transaction_id').alias('transaction_count')
        ),
        on='region'
    )
    
    # Collect results
    results = enriched_stats.collect()
    
    # Write results
    enriched_stats.write \
        .mode('overwrite') \
        .parquet('/data/user_analysis_results.parquet')
    
    return {
        'regions_analyzed': len(results),
        'total_users': sum(row['user_count'] for row in results)
    }

# Use in DAG
user_analysis_task = analyze_users()

PySpark Task with Configuration

from airflow.providers.apache.spark.decorators.pyspark import pyspark_task
from airflow import DAG
from datetime import datetime, timedelta
from pyspark.sql.functions import avg, col, count, lit, sum  # pyspark's sum, shadows the builtin

dag = DAG(
    'pyspark_etl',
    default_args={
        'owner': 'data-team',
        'retries': 2,
        'retry_delay': timedelta(minutes=5),
    },
    description='PySpark ETL workflow',
    schedule_interval='@daily',
    start_date=datetime(2023, 1, 1),
    catchup=False,
)

@pyspark_task(
    task_id='process_daily_data',
    conn_id='spark_cluster',  # Use specific Spark connection
    dag=dag
)
def process_daily_data(spark, **context):
    """
    Process daily data with date templating.
    
    Args:
        spark (SparkSession): Spark session
        **context: Airflow context with templated variables
    """
    execution_date = context['ds']  # YYYY-MM-DD format
    
    # Read data for specific date
    daily_data = spark.read.parquet(f'/data/raw/{execution_date}/*.parquet')
    
    # Data processing
    processed_data = daily_data \
        .filter(col('status') == 'active') \
        .withColumn('processing_date', lit(execution_date)) \
        .groupBy('category', 'processing_date') \
        .agg(
            count('id').alias('record_count'),
            sum('amount').alias('total_amount'),
            avg('amount').alias('avg_amount')
        )
    
    # Write processed data
    output_path = f'/data/processed/{execution_date}'
    processed_data.write \
        .mode('overwrite') \
        .partitionBy('processing_date') \
        .parquet(output_path)
    
    return {
        'execution_date': execution_date,
        'output_path': output_path,
        'record_count': processed_data.count()
    }

process_task = process_daily_data()

Multiple Output PySpark Task

from airflow.decorators import task
from airflow.providers.apache.spark.decorators.pyspark import pyspark_task
from pyspark.sql.functions import avg, count, sum  # pyspark's sum, shadows the builtin

@pyspark_task(
    task_id='generate_multiple_reports',
    multiple_outputs=True  # Enable multiple outputs
)
def generate_reports(spark):
    """
    Generate multiple reports from data analysis.
    
    Returns:
        dict: Multiple named outputs for downstream tasks
    """
    # Load source data
    sales_df = spark.read.table('sales')
    customers_df = spark.read.table('customers')
    
    # Generate sales report
    sales_summary = sales_df.groupBy('region', 'product_category') \
        .agg(
            sum('amount').alias('total_sales'),
            count('order_id').alias('order_count')
        )
    
    sales_report_path = '/reports/sales_summary.parquet'
    sales_summary.write.mode('overwrite').parquet(sales_report_path)
    
    # Generate customer report
    customer_summary = customers_df.groupBy('region', 'segment') \
        .agg(
            count('customer_id').alias('customer_count'),
            avg('lifetime_value').alias('avg_ltv')
        )
    
    customer_report_path = '/reports/customer_summary.parquet'
    customer_summary.write.mode('overwrite').parquet(customer_report_path)
    
    # Return multiple outputs
    return {
        'sales_report': {
            'path': sales_report_path,
            'record_count': sales_summary.count()
        },
        'customer_report': {
            'path': customer_report_path,
            'record_count': customer_summary.count()
        },
        'execution_summary': {
            'total_sales': sales_df.agg(sum('amount')).collect()[0][0],
            'total_customers': customers_df.count()
        }
    }

# Use multiple outputs in downstream tasks
reports_task = generate_reports()

@task
def process_sales_report(sales_data):
    print(f"Processing sales report at: {sales_data['path']}")
    print(f"Records processed: {sales_data['record_count']}")

@task  
def process_customer_report(customer_data):
    print(f"Processing customer report at: {customer_data['path']}")
    print(f"Records processed: {customer_data['record_count']}")

# Access individual outputs
process_sales_report(reports_task['sales_report'])
process_customer_report(reports_task['customer_report'])

PySpark Task with External Dependencies

@pyspark_task(
    task_id='ml_feature_engineering',
    config_kwargs={'spark.jars.packages': 'org.apache.spark:spark-mllib_2.12:3.5.0'},  # extra JVM dependencies
)
def feature_engineering(spark):
    """
    Perform feature engineering using Spark MLlib.
    """
    from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
    from pyspark.ml import Pipeline
    
    # Load raw data
    raw_data = spark.read.csv('/data/raw_features.csv', header=True, inferSchema=True)
    
    # Define feature engineering pipeline
    string_indexer = StringIndexer(
        inputCol='category', 
        outputCol='category_indexed'
    )
    
    vector_assembler = VectorAssembler(
        inputCols=['feature1', 'feature2', 'feature3', 'category_indexed'],
        outputCol='features_raw'
    )
    
    scaler = StandardScaler(
        inputCol='features_raw',
        outputCol='features_scaled'
    )
    
    # Create and fit pipeline
    pipeline = Pipeline(stages=[string_indexer, vector_assembler, scaler])
    pipeline_model = pipeline.fit(raw_data)
    
    # Transform data
    processed_data = pipeline_model.transform(raw_data)
    
    # Save processed features
    processed_data.select('id', 'features_scaled', 'target') \
        .write.mode('overwrite') \
        .parquet('/data/processed_features.parquet')
    
    return {
        'input_records': raw_data.count(),
        'output_records': processed_data.count(),
        'feature_count': len(vector_assembler.getInputCols())
    }

ml_task = feature_engineering()

Configuration and Best Practices

Spark Session Configuration

The decorator automatically configures the Spark session based on:

  1. Airflow Connection: Uses connection settings from conn_id parameter
  2. Default Configuration: Applies sensible defaults for Airflow integration
  3. Environment Variables: Respects Spark environment configuration
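As a sketch of point 1, Airflow resolves connections from `AIRFLOW_CONN_<CONN_ID>` environment variables using its `scheme://host:port?query` URI convention. The `spark_conn_uri` helper and the host name below are hypothetical, for illustration only:

```python
import os

def spark_conn_uri(master_host: str, port: int = 7077, deploy_mode: str = "client") -> str:
    """Build a Spark connection URI in Airflow's scheme://host:port?query convention."""
    return f"spark://{master_host}:{port}?deploy-mode={deploy_mode}"

# A conn_id of 'spark_cluster' (as used in the examples above) could be
# backed by an environment variable like this:
os.environ["AIRFLOW_CONN_SPARK_CLUSTER"] = spark_conn_uri("spark-master.internal")
print(os.environ["AIRFLOW_CONN_SPARK_CLUSTER"])
```

Connections defined this way are picked up without touching the metadata database, which is convenient for containerized deployments.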

Session Management

# Automatic session lifecycle management
@pyspark_task
def my_task(spark):
    # Session is created and configured automatically;
    # no need to call spark.stop() -- cleanup is handled by the decorator
    df = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'value'])
    return df.count()
    # Session is automatically stopped after the function completes

Error Handling

@pyspark_task
def robust_pyspark_task(spark):
    try:
        # PySpark operations
        result = spark.sql("SELECT COUNT(*) FROM large_table")
        return result.collect()[0][0]
        
    except Exception as e:
        # Log error details
        print(f"PySpark task failed: {str(e)}")
        # Session cleanup handled automatically
        raise  # Re-raise for Airflow error handling

Memory and Resource Management

from pyspark.sql.functions import avg, count

@pyspark_task(
    task_id='memory_intensive_task',
    # Configure resources through connection or task parameters
)
def memory_intensive_processing(spark):
    # Runtime-tunable settings; static settings such as spark.serializer
    # must be supplied at session creation (e.g. via the connection config)
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
    
    # Process large datasets
    large_df = spark.read.parquet('/data/very_large_dataset.parquet')
    
    # Use appropriate partitioning before aggregating
    result = large_df.repartition(200) \
        .groupBy('partition_key') \
        .agg(
            count('id').alias('record_count'),
            avg('amount').alias('avg_amount')
        ) \
        .cache()  # Cache intermediate results
    
    return result.count()

Integration with Other Airflow Features

XCom Integration

@pyspark_task
def produce_data(spark, **context):
    # Process data and return summary for downstream tasks
    df = spark.read.table('source_data')
    summary = {
        'record_count': df.count(),
        # Jinja templates are not rendered in return values, so build the
        # path from the task context instead of using '{{ ds }}'
        'output_path': f"/data/processed/{context['ds']}"
    }
    return summary  # Automatically stored in XCom

@task
def consume_data(summary_data):
    # Access PySpark task results
    print(f"Processing {summary_data['record_count']} records")
    print(f"Data location: {summary_data['output_path']}")

# Chain tasks
summary = produce_data()
consume_data(summary)

Templating Support

from pyspark.sql.functions import lit

@pyspark_task(
    task_id='templated_spark_task',
    # op_args and op_kwargs are template fields, so templating works with the decorator
)
def templated_processing(spark, **context):
    # Access Airflow templating
    execution_date = context['ds']
    dag_run_id = context['dag_run'].run_id
    
    # Use in PySpark operations
    df = spark.read.parquet(f'/data/{execution_date}/*.parquet')
    
    # Add metadata columns
    result = df.withColumn('dag_run_id', lit(dag_run_id)) \
              .withColumn('processing_date', lit(execution_date))
    
    return result.count()

Constants and Utilities

Spark Context Keys

SPARK_CONTEXT_KEYS = ["spark", "sc"]

These keys define the parameter names that will receive the Spark session or context objects in decorated functions. The decorator automatically injects the Spark session using the first available parameter name from this list.
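The name-based lookup described above can be sketched with `inspect.signature`. This is an illustration of the mechanism, not the decorator's actual internals; `injectable_params` and the sample functions are hypothetical:

```python
import inspect

SPARK_CONTEXT_KEYS = ["spark", "sc"]

def injectable_params(func):
    """Return which reserved Spark parameter names the function declares."""
    params = inspect.signature(func).parameters
    return [key for key in SPARK_CONTEXT_KEYS if key in params]

def session_only(spark):
    pass

def session_and_context(spark, sc, value):
    pass

print(injectable_params(session_only))         # ['spark']
print(injectable_params(session_and_context))  # ['spark', 'sc']
```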

Error Handling and Debugging

Common Issues

  1. Memory Errors: Configure appropriate executor memory and driver memory
  2. Serialization Errors: Use Kryo serializer for complex data types
  3. Connection Errors: Verify Spark cluster connectivity and authentication
  4. Dependency Conflicts: Ensure compatible PySpark and dependency versions
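Issues 1 and 2 above are usually addressed at session creation. A minimal sketch of a settings dict, shaped so it could be passed through a parameter like the operator's `config_kwargs`; the `build_spark_conf` helper and the memory values are illustrative, though the keys are standard Spark properties:

```python
def build_spark_conf(executor_mem_gb: int, driver_mem_gb: int) -> dict:
    """Assemble memory and serializer settings as standard Spark property keys."""
    return {
        "spark.executor.memory": f"{executor_mem_gb}g",
        "spark.driver.memory": f"{driver_mem_gb}g",
        # Kryo serializes complex/custom types more efficiently than Java serialization
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    }

conf = build_spark_conf(executor_mem_gb=8, driver_mem_gb=4)
print(conf["spark.executor.memory"])  # 8g
```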

Debug Mode

@pyspark_task(
    task_id='debug_task',
    config_kwargs={
        'spark.sql.execution.arrow.pyspark.enabled': 'true',
        'spark.sql.adaptive.enabled': 'true',
        'spark.eventLog.enabled': 'true',
        'spark.eventLog.dir': '/tmp/spark-events'
    }
)
def debug_processing(spark):
    # Enable debug logging
    spark.sparkContext.setLogLevel("DEBUG")
    
    # Your PySpark code with detailed logging
    df = spark.read.parquet('/data/input.parquet')
    print(f"Input partitions: {df.rdd.getNumPartitions()}")
    print(f"Input schema: {df.schema}")
    
    result = df.groupBy('key').count()
    print(f"Result count: {result.count()}")
    
    return result.collect()

Install with Tessl CLI

npx tessl i tessl/pypi-apache-airflow-providers-apache-spark
