Comprehensive Databricks integration for Apache Airflow with operators, hooks, sensors, and triggers for orchestrating data workflows
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
The Databricks provider offers flexible authentication and connection management through specialized hooks that support multiple authentication methods, connection pooling, and robust error handling for both REST API and SQL operations.
Primary hook for Databricks REST API operations with comprehensive authentication support.
from airflow.providers.databricks.hooks.databricks import DatabricksHook
class DatabricksHook(BaseDatabricksHook):
def __init__(
self,
databricks_conn_id: str = "databricks_default",
timeout_seconds: int | None = None,
retry_limit: int = 3,
retry_delay: int = 1,
retry_args: dict[str, Any] | None = None,
caller: str | None = None,
**kwargs
) -> None:
"""
Hook for interacting with Databricks REST API.
Args:
databricks_conn_id: Airflow connection ID for Databricks
timeout_seconds: Request timeout in seconds
retry_limit: Number of retries for failed requests
retry_delay: Base delay between retries in seconds
retry_args: Additional retry configuration (exponential backoff, etc.)
caller: Caller identification for logging and debugging
"""
def submit_run(self, json: dict[str, Any]) -> int:
"""
Submit a one-time run to Databricks.
Args:
json: Run configuration dictionary
Returns:
Run ID of the submitted job
"""
def run_now(self, json: dict[str, Any]) -> int:
"""
Trigger an existing Databricks job.
Args:
json: Job trigger configuration
Returns:
Run ID of the triggered job
"""
def get_run_state(self, run_id: int) -> RunState:
"""
Get current state of a Databricks run.
Args:
run_id: Run ID to check
Returns:
RunState object with current status
"""
def cancel_run(self, run_id: int) -> None:
"""
Cancel a running Databricks job.
Args:
run_id: Run ID to cancel
"""
def get_run_page_url(self, run_id: int) -> str:
"""
Get URL for the Databricks run page.
Args:
run_id: Run ID
Returns:
Direct URL to run details page
"""Specialized hook for SQL operations on Databricks SQL endpoints and clusters.
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
class DatabricksSqlHook(DbApiHook):
def __init__(
self,
databricks_conn_id: str = "databricks_default",
http_path: str | None = None,
session_configuration: dict[str, str] | None = None,
sql_endpoint_name: str | None = None,
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
caller: str | None = None,
**kwargs
) -> None:
"""
Hook for SQL operations on Databricks SQL endpoints.
Args:
databricks_conn_id: Airflow connection ID for Databricks
http_path: HTTP path to SQL endpoint or cluster
session_configuration: Session-level Spark configuration
sql_endpoint_name: Name of SQL endpoint to use
http_headers: Additional HTTP headers for requests
catalog: Default catalog for SQL operations
schema: Default schema for SQL operations
caller: Caller identification for logging
"""
def get_conn(self) -> Connection:
"""
Get database connection for SQL operations.
Returns:
Database connection object
"""
def run(
self,
sql: str | list[str],
autocommit: bool = False,
parameters: dict[str, Any] | None = None,
handler: Callable[[Any], Any] | None = None,
split_statements: bool = False,
return_last: bool = True
) -> Any:
"""
Execute SQL statement(s).
Args:
sql: SQL query or list of queries
autocommit: Whether to autocommit transactions
parameters: Parameters for parameterized queries
handler: Result handler function
split_statements: Whether to split multiple statements
return_last: Return only last result for multiple queries
Returns:
Query results based on handler or default processing
"""
def get_pandas_df(
self,
sql: str,
parameters: dict[str, Any] | None = None,
**kwargs
) -> DataFrame:
"""
Execute SQL query and return results as pandas DataFrame.
Args:
sql: SQL query to execute
parameters: Query parameters
Returns:
pandas DataFrame with query results
"""The most common authentication method using Databricks personal access tokens:
# Connection configuration in Airflow
# Connection ID: databricks_token_auth
# Connection Type: Databricks
# Host: https://your-databricks-workspace.cloud.databricks.com
# Password: dapi1234567890abcdef (your personal access token)
from airflow.providers.databricks.hooks.databricks import DatabricksHook
# Use hook with token authentication
hook = DatabricksHook(
databricks_conn_id='databricks_token_auth',
timeout_seconds=600,
retry_limit=3
)
# Submit job using authenticated connection
run_id = hook.submit_run({
'run_name': 'Token Auth Example',
'notebook_task': {
'notebook_path': '/Shared/example_notebook'
},
'existing_cluster_id': 'cluster-001'
})Authenticate using Azure AD for Azure Databricks workspaces:
# Connection configuration for Azure AD
# Connection ID: databricks_azure_ad
# Connection Type: Databricks
# Host: https://adb-1234567890123456.7.azuredatabricks.net
# Extra: {
# "azure_tenant_id": "12345678-1234-1234-1234-123456789012",
# "azure_client_id": "87654321-4321-4321-4321-210987654321",
# "azure_client_secret": "your_client_secret",
# "use_azure_cli": false
# }
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
# SQL hook with Azure AD authentication
sql_hook = DatabricksSqlHook(
databricks_conn_id='databricks_azure_ad',
http_path='/sql/1.0/warehouses/your-warehouse-id',
catalog='production',
schema='analytics'
)
# Execute query with Azure AD authentication
results = sql_hook.get_pandas_df("""
SELECT customer_id, SUM(order_amount) as total_spent
FROM orders
WHERE order_date >= CURRENT_DATE - INTERVAL 30 DAYS
GROUP BY customer_id
""")Use Azure service principals for programmatic access:
# Connection configuration for Service Principal
# Connection ID: databricks_service_principal
# Connection Type: Databricks
# Host: https://adb-1234567890123456.7.azuredatabricks.net
# Extra: {
# "azure_tenant_id": "12345678-1234-1234-1234-123456789012",
# "azure_client_id": "service-principal-client-id",
# "azure_client_secret": "service-principal-secret"
# }
hook = DatabricksHook(
databricks_conn_id='databricks_service_principal',
retry_limit=5,
retry_delay=2
)
# Create and run job with service principal auth
job_config = {
'name': 'Service Principal Job',
'new_cluster': {
'spark_version': '12.2.x-scala2.12',
'node_type_id': 'Standard_DS3_v2',
'num_workers': 2
},
'notebook_task': {
'notebook_path': '/Production/ETL/daily_pipeline'
},
'timeout_seconds': 3600
}
job_id = hook.create_job(job_config)
run_id = hook.run_now({'job_id': job_id})Authenticate using AWS IAM roles for AWS Databricks workspaces:
# Connection configuration for AWS IAM
# Connection ID: databricks_aws_iam
# Connection Type: Databricks
# Host: https://dbc-12345678-9012.cloud.databricks.com
# Extra: {
# "use_aws_iam_role": true,
# "aws_region": "us-west-2"
# }
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
# SQL operations with IAM role authentication
iam_hook = DatabricksSqlHook(
databricks_conn_id='databricks_aws_iam',
http_path='/sql/1.0/warehouses/warehouse-123',
session_configuration={
'spark.sql.adaptive.enabled': 'true',
'spark.sql.adaptive.coalescePartitions.enabled': 'true'
}
)
# Execute data loading operation
load_result = iam_hook.run("""
COPY INTO production.sales_data
FROM 's3://data-lake/sales/{{ ds }}/'
FILEFORMAT = DELTA
COPY_OPTIONS ('mergeSchema' = 'true')
""")Configure connections for high-throughput scenarios:
# High-performance connection configuration
# Extra configuration for optimized connection:
# {
# "http_timeout_seconds": 300,
# "max_connections": 50,
# "connection_pool_size": 10,
# "retry_config": {
# "max_retries": 5,
# "exponential_backoff": true,
# "base_delay": 1,
# "max_delay": 60
# }
# }
from airflow.providers.databricks.hooks.databricks import DatabricksHook
# Hook with optimized retry configuration
optimized_hook = DatabricksHook(
databricks_conn_id='databricks_high_performance',
timeout_seconds=300,
retry_limit=5,
retry_delay=1,
retry_args={
'stop_max_attempt_number': 5,
'wait_exponential_multiplier': 1000,
'wait_exponential_max': 60000
}
)
# Batch job submission with optimized connection
job_runs = []
for job_config in batch_job_configs:
run_id = optimized_hook.submit_run(job_config)
job_runs.append(run_id)
print(f"Submitted {len(job_runs)} jobs successfully")Manage connections across different environments:
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.models import Variable
def get_environment_hook(environment: str) -> DatabricksHook:
    """Return a DatabricksHook configured for the given environment.

    Args:
        environment: One of 'development', 'staging', or 'production'.

    Returns:
        DatabricksHook bound to the environment's connection ID with
        environment-appropriate timeout and retry settings.

    Raises:
        ValueError: If the environment name is not recognized.
    """
    connection_mapping = {
        'development': 'databricks_dev',
        'staging': 'databricks_staging',
        'production': 'databricks_prod',
    }
    conn_id = connection_mapping.get(environment)
    if not conn_id:
        raise ValueError(f"Unknown environment: {environment}")
    # Environment-specific timeout configuration (seconds)
    timeout_config = {
        'development': 1800,  # 30 minutes for dev
        'staging': 3600,      # 1 hour for staging
        'production': 7200,   # 2 hours for production
    }
    return DatabricksHook(
        databricks_conn_id=conn_id,
        timeout_seconds=timeout_config[environment],
        # Production gets retries; non-prod fails fast to surface issues early
        retry_limit=3 if environment == 'production' else 1,
    )
# Usage in DAG
def submit_environment_job(**context):
env = context['params'].get('environment', 'development')
hook = get_environment_hook(env)
job_config = {
'run_name': f'{env}_data_processing',
'notebook_task': {
'notebook_path': f'/Repos/{env}/data-pipeline/main_notebook'
},
'existing_cluster_id': Variable.get(f'{env}_cluster_id')
}
run_id = hook.submit_run(job_config)
return run_idConfigure custom headers for specialized authentication:
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
# SQL hook with custom authentication headers
custom_auth_hook = DatabricksSqlHook(
databricks_conn_id='databricks_custom_auth',
http_path='/sql/1.0/warehouses/custom-warehouse',
http_headers=[
('X-Custom-Auth-Token', 'your-custom-token'),
('X-Request-Source', 'airflow-pipeline'),
('X-Environment', 'production'),
('User-Agent', 'Airflow-Databricks-Provider/1.0')
],
caller='CustomAuthPipeline'
)
# Execute query with custom headers
query_results = custom_auth_hook.run("""
SELECT
table_name,
COUNT(*) as row_count,
MAX(last_modified) as last_update
FROM information_schema.tables
WHERE table_schema = 'analytics'
GROUP BY table_name
""")Implement connection validation and health monitoring:
from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
def validate_databricks_connection(conn_id: str) -> dict[str, Any]:
    """Validate a Databricks connection and return its health status.

    Probes both the REST API and the SQL endpoint; failures are collected
    in the 'errors' list rather than raised, so callers always get a report.

    Args:
        conn_id: Airflow connection ID to validate.

    Returns:
        Dict with boolean health flags and a list of error messages.
    """
    health_status: dict[str, Any] = {
        'connection_id': conn_id,
        'rest_api_healthy': False,
        'sql_endpoint_healthy': False,
        'clusters_accessible': False,
        'errors': []
    }
    try:
        # Probe the REST API with a cheap read-only call; the result is
        # discarded — success alone is the health signal.
        rest_hook = DatabricksHook(databricks_conn_id=conn_id)
        rest_hook.list_jobs(limit=1)
        health_status['rest_api_healthy'] = True
        health_status['clusters_accessible'] = True
    except Exception as e:
        health_status['errors'].append(f"REST API error: {str(e)}")
    try:
        # Probe the SQL endpoint with a trivial query
        sql_hook = DatabricksSqlHook(databricks_conn_id=conn_id)
        result = sql_hook.run("SELECT 1 as test_connection")
        if result:
            health_status['sql_endpoint_healthy'] = True
    except Exception as e:
        health_status['errors'].append(f"SQL endpoint error: {str(e)}")
    return health_status
# Use in DAG for connection monitoring
def check_connection_health(**context):
    """Airflow task callable: fail loudly when the REST API is unreachable."""
    conn_id = context['params'].get('connection_id', 'databricks_default')
    health_report = validate_databricks_connection(conn_id)
    if health_report['rest_api_healthy']:
        return health_report
    raise ValueError(f"REST API connection failed for {conn_id}")
# Connection health check task
health_check = PythonOperator(
task_id='check_databricks_health',
python_callable=check_connection_health,
params={'connection_id': 'databricks_production'}
)Implement connection fallback strategies:
def get_reliable_databricks_hook(primary_conn: str, fallback_conn: str) -> DatabricksHook:
    """Return a Databricks hook, falling back to a secondary connection.

    Probes the primary connection with a short timeout; if the probe
    fails for any reason, returns a hook bound to the fallback connection
    with a more generous retry budget.

    Args:
        primary_conn: Preferred Airflow connection ID.
        fallback_conn: Connection ID used when the primary is unreachable.

    Returns:
        DatabricksHook bound to whichever connection responded.
    """
    try:
        # Probe with a quick timeout so a dead endpoint does not stall the task
        probe_hook = DatabricksHook(
            databricks_conn_id=primary_conn,
            timeout_seconds=30,
        )
        probe_hook.list_jobs(limit=1)  # cheap read-only API call
        print(f"Using primary connection: {primary_conn}")
        # Return a fresh hook with the normal operational timeout
        return DatabricksHook(
            databricks_conn_id=primary_conn,
            timeout_seconds=600,
            retry_limit=3,
        )
    except Exception as e:
        print(f"Primary connection {primary_conn} failed: {str(e)}")
        print(f"Falling back to: {fallback_conn}")
        return DatabricksHook(
            databricks_conn_id=fallback_conn,
            timeout_seconds=600,
            retry_limit=5,  # more retries on the fallback path
        )
# Usage with fallback
def resilient_job_submission(**context):
"""Submit job with connection fallback."""
hook = get_reliable_databricks_hook(
primary_conn='databricks_primary',
fallback_conn='databricks_secondary'
)
job_config = {
'run_name': 'Resilient Job Submission',
'notebook_task': {
'notebook_path': '/Shared/resilient_pipeline'
},
'existing_cluster_id': 'backup-cluster-001'
}
run_id = hook.submit_run(job_config)
return run_idManage SQL session configurations for optimal performance:
def get_optimized_sql_hook(workload_type: str) -> DatabricksSqlHook:
    """Return a DatabricksSqlHook tuned for a specific workload type.

    Args:
        workload_type: 'etl', 'analytics', or 'ml'. Unknown values fall
            back to an empty session configuration.

    Returns:
        DatabricksSqlHook with workload-specific Spark session settings.
    """
    # Workload-specific Spark session configurations
    configs = {
        'etl': {
            'spark.sql.adaptive.enabled': 'true',
            'spark.sql.adaptive.coalescePartitions.enabled': 'true',
            'spark.sql.adaptive.skewJoin.enabled': 'true',
            'spark.serializer': 'org.apache.spark.serializer.KryoSerializer'
        },
        'analytics': {
            'spark.sql.execution.arrow.pyspark.enabled': 'true',
            'spark.sql.adaptive.enabled': 'true',
            'spark.sql.optimizer.dynamicPartitionPruning.enabled': 'true'
        },
        'ml': {
            'spark.sql.execution.arrow.maxRecordsPerBatch': '10000',
            'spark.sql.adaptive.enabled': 'true',
            'spark.task.maxFailures': '3'
        },
    }
    session_config = configs.get(workload_type, {})
    return DatabricksSqlHook(
        databricks_conn_id='databricks_sql',
        session_configuration=session_config,
        caller=f'OptimizedHook-{workload_type}',
    )
# ETL workload
etl_hook = get_optimized_sql_hook('etl')
etl_results = etl_hook.run("""
INSERT INTO processed_data
SELECT * FROM raw_data
WHERE processing_date = CURRENT_DATE
""")
# Analytics workload
analytics_hook = get_optimized_sql_hook('analytics')
analytics_df = analytics_hook.get_pandas_df("""
SELECT customer_segment, AVG(order_value)
FROM customer_analytics
GROUP BY customer_segment
""")The connection and authentication system provides robust, flexible access to Databricks services with comprehensive error handling, multiple authentication methods, and performance optimization capabilities.
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-databricks