Comprehensive Databricks integration for Apache Airflow with operators, hooks, sensors, and triggers for orchestrating data workflows
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
The Databricks provider offers comprehensive monitoring and sensing capabilities through sensors and triggers that can monitor job completion, data availability, SQL query results, and system status with support for both synchronous and asynchronous (deferrable) execution patterns.
Monitor Databricks job run completion and status with configurable polling and error handling.
from airflow.providers.databricks.sensors.databricks import DatabricksSensor
class DatabricksSensor(BaseSensorOperator):
def __init__(
self,
run_id: int | str,
*,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[str, Any] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs
) -> None:
"""
Sensor for monitoring Databricks job run completion.
Args:
run_id: Databricks run ID to monitor (supports templating)
databricks_conn_id: Airflow connection ID for Databricks
polling_period_seconds: Seconds between status checks
databricks_retry_limit: Number of retries for API calls
databricks_retry_delay: Seconds between API call retries
databricks_retry_args: Additional retry configuration
deferrable: Whether to use deferrable (async) execution
"""

Monitor SQL query results and data conditions on Databricks SQL endpoints.
from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor
class DatabricksSqlSensor(BaseSensorOperator):
def __init__(
self,
sql: str,
*,
databricks_conn_id: str = "databricks_default",
http_path: str | None = None,
sql_endpoint_name: str | None = None,
session_configuration: dict[str, str] | None = None,
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
**kwargs
) -> None:
"""
Sensor for monitoring SQL query results on Databricks.
Args:
sql: SQL query to execute for monitoring (supports templating)
databricks_conn_id: Airflow connection ID for Databricks
http_path: HTTP path to SQL endpoint or cluster
sql_endpoint_name: Name of SQL endpoint to use
session_configuration: Session-level configuration parameters
http_headers: Additional HTTP headers for requests
catalog: Default catalog for SQL operations
schema: Default schema for SQL operations
"""

Monitor table partition availability for data pipeline orchestration.
from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor
class DatabricksPartitionSensor(BaseSensorOperator):
def __init__(
self,
table_name: str,
partitions: dict[str, str] | list[dict[str, str]],
*,
databricks_conn_id: str = "databricks_default",
http_path: str | None = None,
sql_endpoint_name: str | None = None,
catalog: str | None = None,
schema: str | None = None,
**kwargs
) -> None:
"""
Sensor for monitoring table partition availability.
Args:
table_name: Name of table to monitor (supports templating)
partitions: Partition specifications to check for availability
databricks_conn_id: Airflow connection ID for Databricks
http_path: HTTP path to SQL endpoint or cluster
sql_endpoint_name: Name of SQL endpoint to use
catalog: Catalog containing the table
schema: Schema containing the table
"""

Asynchronous trigger for deferrable job monitoring with efficient resource usage.
from airflow.providers.databricks.triggers.databricks import DatabricksTrigger
class DatabricksTrigger(BaseTrigger):
def __init__(
self,
run_id: int,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
**kwargs
) -> None:
"""
Async trigger for monitoring Databricks job runs.
Args:
run_id: Databricks run ID to monitor
databricks_conn_id: Airflow connection ID for Databricks
polling_period_seconds: Seconds between status checks
databricks_retry_limit: Number of retries for API calls
databricks_retry_delay: Seconds between API call retries
"""
async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Async generator that yields trigger events.
Yields:
TriggerEvent with run completion status and metadata
"""

Specialized trigger for monitoring Databricks workflow execution.
from airflow.providers.databricks.triggers.databricks import DatabricksWorkflowTrigger
class DatabricksWorkflowTrigger(BaseTrigger):
def __init__(
self,
run_id: int,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
**kwargs
) -> None:
"""
Async trigger for monitoring Databricks workflow runs.
Args:
run_id: Databricks workflow run ID to monitor
databricks_conn_id: Airflow connection ID for Databricks
polling_period_seconds: Seconds between status checks
databricks_retry_limit: Number of retries for API calls
databricks_retry_delay: Seconds between API call retries
"""

Monitor job completion with a simple sensor configuration:
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
from airflow.providers.databricks.sensors.databricks import DatabricksSensor
# Submit job and monitor completion
submit_job = DatabricksSubmitRunOperator(
task_id='submit_data_processing',
notebook_task={
'notebook_path': '/Analytics/Daily Processing',
'base_parameters': {'date': '{{ ds }}'}
},
existing_cluster_id='processing-cluster-001',
do_xcom_push=True
)
# Monitor job completion
monitor_job = DatabricksSensor(
task_id='wait_for_processing_completion',
run_id="{{ task_instance.xcom_pull(task_ids='submit_data_processing', key='run_id') }}",
databricks_conn_id='databricks_default',
poke_interval=60, # Check every minute
timeout=7200 # Timeout after 2 hours
)
submit_job >> monitor_job

Use deferrable execution for efficient resource utilization:
# Long-running job with deferrable monitoring
long_job = DatabricksSubmitRunOperator(
task_id='submit_ml_training',
spark_python_task={
'python_file': 'dbfs:/ml/training/train_model.py',
'parameters': ['--epochs', '500', '--data-size', 'large']
},
new_cluster={
'spark_version': '12.2.x-cpu-ml-scala2.12',
'node_type_id': 'i3.4xlarge',
'num_workers': 10
},
timeout_seconds=28800, # 8 hours
deferrable=True # Use deferrable execution
)
# Deferrable sensor - doesn't consume worker slot while waiting
deferrable_monitor = DatabricksSensor(
task_id='monitor_ml_training',
run_id="{{ task_instance.xcom_pull(task_ids='submit_ml_training', key='run_id') }}",
databricks_conn_id='databricks_ml',
polling_period_seconds=300, # Check every 5 minutes
timeout=28800, # 8 hour timeout
deferrable=True # Async monitoring
)
long_job >> deferrable_monitor

Monitor data availability and quality using SQL sensors:
from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor
# Wait for daily data to arrive
data_availability_sensor = DatabricksSqlSensor(
task_id='wait_for_daily_data',
sql="""
SELECT COUNT(*) as record_count
FROM raw.daily_transactions
WHERE transaction_date = '{{ ds }}'
HAVING COUNT(*) >= 10000
""",
databricks_conn_id='databricks_sql',
http_path='/sql/1.0/warehouses/analytics-warehouse',
poke_interval=300, # Check every 5 minutes
timeout=14400, # Wait up to 4 hours
catalog='production',
schema='raw'
)
# Monitor data quality thresholds
quality_sensor = DatabricksSqlSensor(
task_id='check_data_quality',
sql="""
SELECT
COUNT(*) as total_records,
SUM(CASE WHEN customer_id IS NOT NULL THEN 1 ELSE 0 END) as valid_customers,
SUM(CASE WHEN amount > 0 THEN 1 ELSE 0 END) as valid_amounts
FROM processed.daily_sales
WHERE processing_date = '{{ ds }}'
HAVING
(valid_customers * 100.0 / total_records) >= 95
AND (valid_amounts * 100.0 / total_records) >= 98
AND total_records >= 1000
""",
databricks_conn_id='databricks_sql',
poke_interval=180,
timeout=3600
)
data_availability_sensor >> quality_sensor

Monitor table partition availability for data pipeline coordination:
from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor
# Wait for specific date partition
partition_sensor = DatabricksPartitionSensor(
task_id='wait_for_daily_partition',
table_name='sales.daily_transactions',
partitions={'date': '{{ ds }}'},
databricks_conn_id='databricks_sql',
catalog='production',
schema='sales',
poke_interval=600, # Check every 10 minutes
timeout=21600 # Wait up to 6 hours
)
# Wait for multiple partitions
multi_partition_sensor = DatabricksPartitionSensor(
task_id='wait_for_regional_partitions',
table_name='analytics.regional_metrics',
partitions=[
{'date': '{{ ds }}', 'region': 'north_america'},
{'date': '{{ ds }}', 'region': 'europe'},
{'date': '{{ ds }}', 'region': 'asia_pacific'}
],
databricks_conn_id='databricks_analytics',
poke_interval=300,
timeout=7200
)

Implement conditional workflows based on data monitoring results:
from airflow.operators.python import BranchPythonOperator
from airflow.operators.dummy import DummyOperator
def check_data_completeness(**context):
"""Check data completeness and branch accordingly."""
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
hook = DatabricksSqlHook(databricks_conn_id='databricks_sql')
# Check data completeness
result = hook.get_first("""
SELECT
COUNT(*) as record_count,
COUNT(DISTINCT source_system) as source_count
FROM raw.daily_ingestion
WHERE ingestion_date = '{{ ds }}'
""")
record_count = result[0] if result else 0
source_count = result[1] if result else 0
# Expected: 3 source systems, minimum 50000 records
if source_count >= 3 and record_count >= 50000:
return 'full_processing'
elif record_count >= 10000:
return 'partial_processing'
else:
return 'wait_longer'
# Branching based on data status
data_check = BranchPythonOperator(
task_id='check_data_status',
python_callable=check_data_completeness
)
full_processing = DatabricksSubmitRunOperator(
task_id='full_processing',
notebook_task={
'notebook_path': '/pipelines/full_daily_pipeline'
},
existing_cluster_id='large-cluster-001'
)
partial_processing = DatabricksSubmitRunOperator(
task_id='partial_processing',
notebook_task={
'notebook_path': '/pipelines/partial_daily_pipeline'
},
existing_cluster_id='small-cluster-001'
)
wait_longer = DummyOperator(task_id='wait_longer')
data_check >> [full_processing, partial_processing, wait_longer]

Implement cascading monitors for complex data dependencies:
from airflow.utils.task_group import TaskGroup
with TaskGroup(group_id='data_dependency_monitoring') as monitoring_group:
# Level 1: Raw data availability
raw_data_monitor = DatabricksSqlSensor(
task_id='monitor_raw_data',
sql="""
SELECT 1
FROM information_schema.tables
WHERE table_name = 'raw_events_{{ ds_nodash }}'
AND table_schema = 'landing'
""",
databricks_conn_id='databricks_sql',
poke_interval=120,
timeout=7200
)
# Level 2: Data processing completion
processing_monitor = DatabricksSqlSensor(
task_id='monitor_processing_completion',
sql="""
SELECT 1
FROM processing_status
WHERE process_date = '{{ ds }}'
AND status = 'COMPLETED'
AND error_count = 0
""",
databricks_conn_id='databricks_sql',
poke_interval=300,
timeout=10800
)
# Level 3: Quality validation
quality_monitor = DatabricksSqlSensor(
task_id='monitor_quality_validation',
sql="""
SELECT 1
FROM quality_metrics
WHERE validation_date = '{{ ds }}'
AND overall_score >= 0.95
AND critical_failures = 0
""",
databricks_conn_id='databricks_sql',
poke_interval=180,
timeout=3600
)
# Set up monitoring cascade
raw_data_monitor >> processing_monitor >> quality_monitor

Monitor streaming data pipelines and real-time processing:
# Monitor streaming job health
streaming_monitor = DatabricksSqlSensor(
task_id='monitor_streaming_health',
sql="""
SELECT
stream_id,
batch_duration_ms,
input_size,
processing_time_ms
FROM streaming_metrics
WHERE
metric_timestamp >= CURRENT_TIMESTAMP - INTERVAL 5 MINUTES
AND batch_duration_ms > 0
AND processing_time_ms < batch_duration_ms * 0.8 -- Processing within 80% of batch interval
HAVING COUNT(*) >= 3 -- At least 3 healthy batches in last 5 minutes
""",
databricks_conn_id='databricks_streaming',
poke_interval=60, # Check every minute
timeout=1800, # 30 minute timeout
mode='reschedule' # Don't block worker
)
# Monitor streaming lag
lag_monitor = DatabricksSqlSensor(
task_id='monitor_streaming_lag',
sql="""
SELECT 1
FROM (
SELECT MAX(event_timestamp) as latest_processed
FROM processed_events
) processed
CROSS JOIN (
SELECT CURRENT_TIMESTAMP as current_time
) current
WHERE TIMESTAMPDIFF(MINUTE, latest_processed, current_time) <= 10 -- Max 10 minutes lag
""",
databricks_conn_id='databricks_streaming',
poke_interval=300,
timeout=3600
)

Implement monitoring with error detection and alerting:
def monitor_with_alerting(**context):
"""Monitor job with custom error handling and alerting."""
from airflow.providers.databricks.hooks.databricks import DatabricksHook
run_id = context['ti'].xcom_pull(task_ids='submit_critical_job', key='run_id')
hook = DatabricksHook(databricks_conn_id='databricks_production')
import time
timeout = 7200 # 2 hours
start_time = time.time()
while time.time() - start_time < timeout:
run_state = hook.get_run_state(run_id)
if run_state.is_terminal:
if run_state.is_successful:
print(f"Job {run_id} completed successfully")
return True
else:
# Job failed - extract error details
run_details = hook.get_run(run_id)
error_message = run_details.get('state', {}).get('state_message', 'Unknown error')
# Send alert (integrate with your alerting system)
send_alert(
message=f"Critical Databricks job {run_id} failed: {error_message}",
severity='HIGH',
job_url=hook.get_run_page_url(run_id)
)
raise ValueError(f"Databricks job {run_id} failed: {error_message}")
time.sleep(60) # Check every minute
# Timeout occurred
send_alert(
message=f"Databricks job {run_id} timed out after {timeout} seconds",
severity='MEDIUM',
job_url=hook.get_run_page_url(run_id)
)
raise TimeoutError(f"Job monitoring timed out for run {run_id}")
def send_alert(message: str, severity: str, job_url: str):
"""Send alert through configured alerting system."""
# Implement your alerting logic here
# (Slack, email, PagerDuty, etc.)
print(f"ALERT [{severity}]: {message}")
print(f"Job URL: {job_url}")
# Custom monitoring with alerting
custom_monitor = PythonOperator(
task_id='monitor_with_alerts',
python_callable=monitor_with_alerting
)

Monitor job performance and resource utilization:
performance_monitor = DatabricksSqlSensor(
task_id='monitor_job_performance',
sql="""
WITH job_metrics AS (
SELECT
run_id,
execution_duration_ms,
cluster_size,
shuffle_read_bytes,
shuffle_write_bytes,
peak_memory_usage
FROM job_execution_metrics
WHERE job_name = '{{ params.job_name }}'
AND start_time >= CURRENT_DATE
)
SELECT 1
FROM job_metrics
WHERE
execution_duration_ms < {{ params.max_duration_ms }}
AND peak_memory_usage < {{ params.max_memory_bytes }}
AND shuffle_read_bytes < {{ params.max_shuffle_bytes }}
ORDER BY run_id DESC
LIMIT 1
""",
params={
'job_name': 'daily_etl_pipeline',
'max_duration_ms': 7200000, # 2 hours
'max_memory_bytes': 32 * 1024**3, # 32GB
'max_shuffle_bytes': 100 * 1024**3 # 100GB
},
databricks_conn_id='databricks_metrics',
poke_interval=300,
timeout=3600
)

The monitoring and sensing capabilities provide comprehensive tools for tracking job execution, data availability, quality metrics, and system health, with both synchronous and asynchronous execution patterns to optimize resource usage and provide timely notifications of pipeline status.
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-databricks