Provider package apache-airflow-providers-snowflake for Apache Airflow
—
Quality
Pending
Does it follow best practices?
Impact
Pending
No eval scenarios have been run
Core connectivity layer providing comprehensive Snowflake database integration through both standard connection protocols and Snowflake's SQL API. These hooks manage authentication, connection pooling, session management, and provide the foundation for all Snowflake operations in Airflow.
The primary hook for Snowflake database connections, providing traditional SQL execution, transaction management, and Snowpark session support.
class SnowflakeHook(DbApiHook):
"""
Hook for interacting with Snowflake database.
Provides connection management, SQL execution, and Snowpark integration.
"""
conn_name_attr = "snowflake_conn_id"
default_conn_name = "snowflake_default"
conn_type = "snowflake"
hook_name = "Snowflake"
supports_autocommit = True
def __init__(
self,
snowflake_conn_id: str = "snowflake_default",
account: str | None = None,
authenticator: str | None = None,
warehouse: str | None = None,
database: str | None = None,
region: str | None = None,
role: str | None = None,
schema: str | None = None,
session_parameters: dict | None = None,
insecure_mode: bool = False,
client_request_mfa_token: bool = False,
client_store_temporary_credential: bool = True,
*args,
**kwargs
) -> None:
"""
Initialize Snowflake hook.
Parameters:
- snowflake_conn_id: Reference to Snowflake connection id
- account: Snowflake account name
- authenticator: Authentication method ('snowflake', 'externalbrowser', or Okta URL)
- warehouse: Name of Snowflake warehouse
- database: Name of Snowflake database
- region: Name of Snowflake region
- role: Name of Snowflake role
- schema: Name of Snowflake schema
- session_parameters: Session-level parameters
- insecure_mode: Turns off OCSP certificate checks
- client_request_mfa_token: Request MFA token from client
- client_store_temporary_credential: Store temporary credentials on client
"""def get_conn(self) -> SnowflakeConnection:
"""
Return a snowflake.connector.connection object.
Returns:
SnowflakeConnection object configured with connection parameters
"""
def get_uri(self) -> str:
"""
Override DbApiHook get_uri method for get_sqlalchemy_engine().
Returns:
SQLAlchemy connection URI string
"""
def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
Get an sqlalchemy_engine object.
Parameters:
- engine_kwargs: Additional arguments for SQLAlchemy engine
Returns:
SQLAlchemy Engine object
"""
def get_snowpark_session(self):
"""
Get a Snowpark session object for DataFrame operations.
Returns:
Snowpark Session object
"""def get_oauth_token(
self,
conn_config: dict | None = None,
token_endpoint: str | None = None,
grant_type: str = "refresh_token"
) -> str:
"""
Generate temporary OAuth access token for authentication.
Parameters:
- conn_config: Connection configuration override
- token_endpoint: OAuth token endpoint URL
- grant_type: OAuth grant type
Returns:
OAuth access token string
"""def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
return_dictionaries: bool = False
):
"""
Run SQL commands against the Snowflake database.
Parameters:
- sql: SQL statement(s) to execute
- autocommit: Enable autocommit mode
- parameters: Query parameters for parameterized queries
- handler: Function to process cursor results
- split_statements: Split multiple statements for execution
- return_last: Return only last statement result
- return_dictionaries: Return results as dictionaries
Returns:
Query results processed by handler or raw cursor results
"""
def set_autocommit(self, conn, autocommit: Any) -> None:
"""
Set autocommit mode for connection.
Parameters:
- conn: Database connection
- autocommit: Boolean autocommit setting
"""
def get_autocommit(self, conn):
"""
Get current autocommit mode for connection.
Parameters:
- conn: Database connection
Returns:
Current autocommit setting
"""def get_openlineage_database_info(self, connection) -> DatabaseInfo:
"""
Get database information for OpenLineage integration.
Parameters:
- connection: Database connection object
Returns:
DatabaseInfo object with connection details
"""
def get_openlineage_database_dialect(self, _) -> str:
"""
Get database dialect for OpenLineage ('snowflake').
Returns:
Database dialect string
"""
def get_openlineage_default_schema(self) -> str | None:
"""
Get default schema for OpenLineage integration.
Returns:
Default schema name or None
"""
def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
"""
Get OpenLineage lineage metadata for task execution.
Parameters:
- task_instance: Airflow TaskInstance object
Returns:
OperatorLineage object or None
"""@property
def account_identifier(self) -> str:
"""
Get Snowflake account identifier.
Returns:
Account identifier string
"""@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""
Return connection widgets for Airflow UI form.
Returns:
Dictionary of form widgets
"""
@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
"""
Return custom field behaviour for connection form.
Returns:
Dictionary defining UI field behavior
"""Advanced hook providing Snowflake SQL API integration for submitting multiple SQL statements in single requests, with support for asynchronous execution and JWT authentication.
class SnowflakeSqlApiHook(SnowflakeHook):
    """
    Hook for interacting with Snowflake using SQL API.

    Enables submission of multiple SQL statements in a single request.
    """

    LIFETIME = timedelta(minutes=59)  # JWT Token lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # JWT Token renewal time

    def __init__(
        self,
        snowflake_conn_id: str,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        api_retry_args: dict[Any, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        """
        Initialize SQL API hook.

        Parameters:
        - snowflake_conn_id: Snowflake connection ID
        - token_life_time: JWT token lifetime
        - token_renewal_delta: JWT token renewal interval
        - api_retry_args: Retry configuration for API calls
        """

    def get_private_key(self) -> None:
        """
        Get the private key from Snowflake connection for JWT authentication.
        """

    def get_headers(self) -> dict[str, Any]:
        """
        Form authentication headers based on OAuth or JWT token.

        Returns:
            Dictionary containing HTTP headers for API requests
        """

    def execute_query(
        self,
        sql: str,
        statement_count: int,
        query_tag: str = "",
        bindings: dict[str, Any] | None = None,
    ) -> list[str]:
        """
        Execute query using Snowflake SQL API.

        Parameters:
        - sql: SQL statements to execute
        - statement_count: Number of statements in SQL
        - query_tag: Optional query tag for tracking
        - bindings: Parameter bindings for queries

        Returns:
            List of query IDs for submitted statements
        """

    def check_query_output(self, query_ids: list[str]) -> None:
        """
        Log query responses for given query IDs.

        Parameters:
        - query_ids: List of query IDs to check
        """

    def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
        """
        Get query status via SQL API.

        Parameters:
        - query_id: Query ID to check status

        Returns:
            Dictionary containing query status and metadata
        """

    def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
        """
        Get query status asynchronously via SQL API.

        Parameters:
        - query_id: Query ID to check status

        Returns:
            Dictionary containing query status and metadata
        """

    def wait_for_query(
        self,
        query_id: str,
        raise_error: bool = False,
        poll_interval: int = 5,
        timeout: int = 60,
    ) -> dict[str, str | list[str]]:
        """
        Wait for query completion with polling.

        Parameters:
        - query_id: Query ID to wait for
        - raise_error: Raise exception on query failure
        - poll_interval: Polling interval in seconds
        - timeout: Maximum wait time in seconds

        Returns:
            Final query status dictionary
        """

    def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
        """
        Get results from successful query execution.

        Parameters:
        - query_id: Query ID to retrieve results

        Returns:
            List of result rows as dictionaries
        """

    def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
        """
        Build request URL and headers for SQL API calls.

        Parameters:
        - query_id: Query ID for request

        Returns:
            Tuple of (headers, params, url)
        """


from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
def execute_sql_query():
    """Run a single query through SnowflakeHook and return a scalar result."""
    hook = SnowflakeHook(
        snowflake_conn_id='my_snowflake_conn',
        warehouse='COMPUTE_WH',
        database='ANALYTICS',
        schema='PUBLIC'
    )
    # Execute single query; the handler extracts the scalar from the cursor
    result = hook.run(
        sql="SELECT COUNT(*) FROM users WHERE created_date >= '2024-01-01'",
        handler=lambda cursor: cursor.fetchone()[0]
    )
    return result


def execute_multiple_queries():
    """Run several dependent statements in order through SnowflakeHook."""
    hook = SnowflakeHook(snowflake_conn_id='my_snowflake_conn')
    # Execute multiple statements; hook.run splits and runs them sequentially
    hook.run([
        "CREATE TEMP TABLE temp_data AS SELECT * FROM raw_data",
        "UPDATE temp_data SET status = 'processed' WHERE id > 1000",
        "INSERT INTO processed_data SELECT * FROM temp_data"
    ])


from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
from datetime import timedelta  # required: token_life_time below is a timedelta


def execute_with_api():
    """Submit a multi-statement request via the Snowflake SQL API and poll for completion."""
    hook = SnowflakeSqlApiHook(
        snowflake_conn_id='my_snowflake_conn',
        token_life_time=timedelta(minutes=30)
    )
    # Submit multiple statements via API; returns one query ID per statement
    query_ids = hook.execute_query(
        sql="""
        CREATE TABLE IF NOT EXISTS sales_summary AS
        SELECT region, SUM(amount) as total_sales
        FROM sales
        GROUP BY region;
        UPDATE sales_summary
        SET total_sales = total_sales * 1.1
        WHERE region = 'WEST';
        """,
        statement_count=2,
        query_tag="monthly_summary"
    )
    # Wait for completion; raise_error=True surfaces failed statements
    for query_id in query_ids:
        status = hook.wait_for_query(query_id, raise_error=True)
        print(f"Query {query_id} completed with status: {status['status']}")


from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
def snowpark_processing():
hook = SnowflakeHook(snowflake_conn_id='my_snowflake_conn')
# Get Snowpark session
session = hook.get_snowpark_session()
# Use Snowpark DataFrame API
df = session.table("raw_sales_data")
# Transform data
processed_df = (df
.filter(df.col("amount") > 100)
.group_by("region")
.agg({"amount": "sum"})
.with_column_renamed("SUM(AMOUNT)", "total_sales"))
# Save results
processed_df.write.save_as_table("regional_sales", mode="overwrite")
return processed_df.count()Both hooks provide comprehensive error handling:
All exceptions include detailed Snowflake error codes and descriptive messages for troubleshooting.
Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-snowflake