# Database Connections and Hooks

The core connectivity layer for Snowflake in Airflow. `SnowflakeHook` connects through the Snowflake Python connector, with SQLAlchemy and Snowpark access built on the same connection settings, while `SnowflakeSqlApiHook` submits statements through Snowflake's SQL API. The hooks handle authentication, connection and session setup, and SQL execution, and provide the foundation for Snowflake operations in Airflow.

## Capabilities
### Standard Database Hook

The primary hook for Snowflake database connections. It runs SQL through the Snowflake Python connector's DB-API interface, manages transactions and autocommit, and can create Snowpark sessions.

```python { .api }
class SnowflakeHook(DbApiHook):
    """
    Hook for interacting with a Snowflake database.
    Provides connection management, SQL execution, and Snowpark integration.
    """

    conn_name_attr = "snowflake_conn_id"
    default_conn_name = "snowflake_default"
    conn_type = "snowflake"
    hook_name = "Snowflake"
    supports_autocommit = True

    def __init__(
        self,
        snowflake_conn_id: str = "snowflake_default",
        account: str | None = None,
        authenticator: str | None = None,
        warehouse: str | None = None,
        database: str | None = None,
        region: str | None = None,
        role: str | None = None,
        schema: str | None = None,
        session_parameters: dict | None = None,
        insecure_mode: bool = False,
        client_request_mfa_token: bool = False,
        client_store_temporary_credential: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """
        Initialize the Snowflake hook.

        Parameters:
        - snowflake_conn_id: ID of the Airflow connection to use
        - account: Snowflake account name
        - authenticator: Authentication method ('snowflake', 'externalbrowser', or an Okta URL)
        - warehouse: Name of the Snowflake warehouse
        - database: Name of the Snowflake database
        - region: Name of the Snowflake region
        - role: Name of the Snowflake role
        - schema: Name of the Snowflake schema
        - session_parameters: Session-level parameters
        - insecure_mode: Turn off OCSP certificate checks
        - client_request_mfa_token: Request an MFA token from the client
        - client_store_temporary_credential: Store temporary credentials on the client
        """
```
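
Hook arguments such as `warehouse`, `role`, or `schema` can also be stored in the Airflow connection's extras. The sketch below illustrates one plausible resolution rule, assuming (from our reading of the provider) that a non-empty hook argument takes precedence over the matching connection extra. `resolve_conn_params` and `OVERRIDABLE_KEYS` are hypothetical names, not part of the provider API.

```python
from typing import Any, Optional

# Hypothetical: keys that may come either from connection extras or hook arguments.
OVERRIDABLE_KEYS = ("account", "warehouse", "database", "region", "role", "schema", "authenticator")


def resolve_conn_params(extras: dict[str, Any], **hook_args: Optional[str]) -> dict[str, Any]:
    """Hypothetical helper: merge connection extras with hook arguments.

    Assumption: a non-empty hook argument wins over the connection extra.
    """
    resolved: dict[str, Any] = {}
    for key in OVERRIDABLE_KEYS:
        override = hook_args.get(key)
        # Fall back to the connection extra when the hook argument is None or empty
        resolved[key] = override if override else extras.get(key)
    return resolved


extras = {"account": "xy12345", "warehouse": "LOAD_WH", "role": "ANALYST"}
params = resolve_conn_params(extras, warehouse="COMPUTE_WH", schema="PUBLIC")
print(params["warehouse"], params["role"], params["schema"])  # COMPUTE_WH ANALYST PUBLIC
```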

#### Connection Management

```python { .api }
def get_conn(self) -> SnowflakeConnection:
    """
    Return a snowflake.connector.connection object.

    Returns:
        SnowflakeConnection object configured with connection parameters
    """

def get_uri(self) -> str:
    """
    Override DbApiHook.get_uri() so that get_sqlalchemy_engine() receives a Snowflake URI.

    Returns:
        SQLAlchemy connection URI string
    """

def get_sqlalchemy_engine(self, engine_kwargs=None):
    """
    Get a SQLAlchemy Engine object.

    Parameters:
    - engine_kwargs: Additional arguments for the SQLAlchemy engine

    Returns:
        SQLAlchemy Engine object
    """

def get_snowpark_session(self):
    """
    Get a Snowpark session object for DataFrame operations.

    Returns:
        Snowpark Session object
    """
```
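
`get_uri()` lets `get_sqlalchemy_engine()` hand SQLAlchemy a Snowflake URL. As far as we can tell, the provider delegates URL construction to the `snowflake-sqlalchemy` dialect, so the stdlib sketch below only illustrates the general URL shape, `snowflake://user:password@account/database/schema?warehouse=...&role=...`. `build_snowflake_uri` is hypothetical, and the dialect's own escaping and parameter order may differ.

```python
from urllib.parse import quote, urlencode


def build_snowflake_uri(
    user: str,
    password: str,
    account: str,
    database: str = "",
    schema: str = "",
    **query: str,
) -> str:
    """Hypothetical helper: assemble a snowflake-sqlalchemy style URL."""
    # Percent-encode credentials so characters like '@' or '/' cannot break the URL
    auth = f"{quote(user, safe='')}:{quote(password, safe='')}"
    path = ""
    if database:
        path = f"/{quote(database, safe='')}"
        if schema:
            path += f"/{quote(schema, safe='')}"
    # Remaining keyword arguments (warehouse, role, ...) become query parameters
    params = {key: value for key, value in query.items() if value}
    query_string = f"?{urlencode(params)}" if params else ""
    return f"snowflake://{auth}@{account}{path}{query_string}"


uri = build_snowflake_uri(
    "etl_user", "p@ss/word", "xy12345", "ANALYTICS", "PUBLIC",
    warehouse="COMPUTE_WH", role="ANALYST",
)
print(uri)
# snowflake://etl_user:p%40ss%2Fword@xy12345/ANALYTICS/PUBLIC?warehouse=COMPUTE_WH&role=ANALYST
```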

#### Authentication

```python { .api }
def get_oauth_token(
    self,
    conn_config: dict | None = None,
    token_endpoint: str | None = None,
    grant_type: str = "refresh_token",
) -> str:
    """
    Generate a temporary OAuth access token for authentication.

    Parameters:
    - conn_config: Connection configuration override
    - token_endpoint: OAuth token endpoint URL
    - grant_type: OAuth grant type

    Returns:
        OAuth access token string
    """
```
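
With the default `grant_type="refresh_token"`, the exchange follows the standard OAuth 2.0 refresh-token grant (RFC 6749, section 6): a form-encoded POST carrying the refresh token, with the client ID and secret sent as HTTP Basic credentials. The sketch builds such a request without sending it. The `/oauth/token-request` path is Snowflake's OAuth token endpoint as we understand it, and `build_refresh_request` and all credential values are hypothetical.

```python
import base64
from urllib.parse import urlencode
from urllib.request import Request


def build_refresh_request(
    account_url: str, client_id: str, client_secret: str, refresh_token: str
) -> Request:
    """Hypothetical helper: build (but do not send) an OAuth 2.0 refresh-token request."""
    # Form-encoded body for the refresh_token grant
    body = urlencode({"grant_type": "refresh_token", "refresh_token": refresh_token}).encode()
    # Client credentials travel in an HTTP Basic Authorization header
    basic = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    return Request(
        url=f"{account_url}/oauth/token-request",  # assumed Snowflake token endpoint path
        data=body,
        headers={
            "Authorization": f"Basic {basic}",
            "Content-Type": "application/x-www-form-urlencoded",
        },
        method="POST",
    )


req = build_refresh_request(
    "https://xy12345.snowflakecomputing.com", "my-client", "my-secret", "rt-123"
)
print(req.get_method(), req.full_url)
# POST https://xy12345.snowflakecomputing.com/oauth/token-request
```

In the standard flow, the `access_token` field of the JSON response is the short-lived token used for authentication.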

#### SQL Execution

```python { .api }
def run(
    self,
    sql: str | Iterable[str],
    autocommit: bool = False,
    parameters: Iterable | Mapping[str, Any] | None = None,
    handler: Callable[[Any], T] | None = None,
    split_statements: bool = True,
    return_last: bool = True,
    return_dictionaries: bool = False,
):
    """
    Run SQL commands against the Snowflake database.

    Parameters:
    - sql: SQL statement(s) to execute
    - autocommit: Enable autocommit mode
    - parameters: Query parameters for parameterized queries
    - handler: Function applied to the cursor after each statement
    - split_statements: Split a multi-statement SQL string into individual statements
    - return_last: Return only the last statement's result
    - return_dictionaries: Return result rows as dictionaries

    Returns:
        Handler results (only the last one when a single SQL string is run with
        return_last=True); None when no handler is given
    """

def set_autocommit(self, conn, autocommit: Any) -> None:
    """
    Set autocommit mode for a connection.

    Parameters:
    - conn: Database connection
    - autocommit: Boolean autocommit setting
    """

def get_autocommit(self, conn):
    """
    Get the current autocommit mode for a connection.

    Parameters:
    - conn: Database connection

    Returns:
        Current autocommit setting
    """
```
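
The `handler`, `split_statements`, and `return_last` parameters interact in ways that are easy to misread. The sketch below models that interaction with `sqlite3` standing in for the Snowflake driver. `run_like_dbapi` is hypothetical and is modelled on the common SQL provider's `DbApiHook.run` behaviour as we understand it: a single SQL string yields only the last handler result when `return_last=True`, while a list of statements yields one result per statement. Its semicolon splitting is deliberately naive; as far as we know, the Snowflake hook relies on the connector's own statement splitter.

```python
import sqlite3
from typing import Any, Callable, Iterable, Optional, Union


def run_like_dbapi(
    conn: sqlite3.Connection,
    sql: Union[str, Iterable[str]],
    handler: Optional[Callable[[sqlite3.Cursor], Any]] = None,
    split_statements: bool = True,
    return_last: bool = True,
) -> Any:
    """Hypothetical stand-in for DbApiHook.run semantics, using sqlite3."""
    if isinstance(sql, str):
        # Naive split on ';' (would break on semicolons inside string literals)
        statements = [s.strip() for s in sql.split(";") if s.strip()] if split_statements else [sql]
    else:
        statements = list(sql)

    results = []
    cursor = conn.cursor()
    for statement in statements:
        cursor.execute(statement)
        if handler is not None:
            # The handler receives the cursor after each statement
            results.append(handler(cursor))
    conn.commit()

    if handler is None:
        return None
    # A single SQL string collapses to the last result; a list keeps every result
    if isinstance(sql, str) and (return_last or not split_statements):
        return results[-1]
    return results


conn = sqlite3.connect(":memory:")
run_like_dbapi(conn, "CREATE TABLE t (x INTEGER); INSERT INTO t VALUES (1), (2), (3)")
count = run_like_dbapi(conn, "SELECT COUNT(*) FROM t", handler=lambda c: c.fetchone()[0])
both = run_like_dbapi(
    conn, ["SELECT MIN(x) FROM t", "SELECT MAX(x) FROM t"], handler=lambda c: c.fetchone()[0]
)
print(count, both)  # 3 [1, 3]
```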

#### OpenLineage Integration

```python { .api }
def get_openlineage_database_info(self, connection) -> DatabaseInfo:
    """
    Get database information for OpenLineage integration.

    Parameters:
    - connection: Database connection object

    Returns:
        DatabaseInfo object with connection details
    """

def get_openlineage_database_dialect(self, _) -> str:
    """
    Get database dialect for OpenLineage ('snowflake').

    Returns:
        Database dialect string
    """

def get_openlineage_default_schema(self) -> str | None:
    """
    Get default schema for OpenLineage integration.

    Returns:
        Default schema name or None
    """

def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
    """
    Get OpenLineage lineage metadata for task execution.

    Parameters:
    - task_instance: Airflow TaskInstance object

    Returns:
        OperatorLineage object or None
    """
```
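
OpenLineage's naming conventions identify a Snowflake dataset by a namespace derived from the account (of the form `snowflake://<organization>-<account>`) and a `DATABASE.SCHEMA.TABLE` name, so a default database and schema are needed to expand partially qualified table references. A sketch of that expansion; `qualify_table` and the sample namespace are illustrative assumptions, not the provider's exact output, and quoted (case-sensitive) identifiers are not handled.

```python
def qualify_table(table: str, default_database: str, default_schema: str) -> str:
    """Hypothetical helper: expand a table reference to DATABASE.SCHEMA.TABLE."""
    parts = table.split(".")
    if len(parts) == 1:
        # Bare table name: fill in both database and schema
        parts = [default_database, default_schema, parts[0]]
    elif len(parts) == 2:
        # schema.table: fill in the database only
        parts = [default_database, *parts]
    elif len(parts) != 3:
        raise ValueError(f"unexpected table reference: {table!r}")
    # Unquoted Snowflake identifiers resolve to upper case (quoted ones are not handled)
    return ".".join(part.upper() for part in parts)


namespace = "snowflake://myorg-myaccount"  # illustrative namespace value
print(namespace, qualify_table("orders", "analytics", "public"))
# snowflake://myorg-myaccount ANALYTICS.PUBLIC.ORDERS
```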

#### Properties

```python { .api }
@property
def account_identifier(self) -> str:
    """
    Get Snowflake account identifier.

    Returns:
        Account identifier string
    """
```
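
The SQL API hook builds its endpoint URLs from the account identifier. Based on our reading of the provider, the identifier has the form `https://<account>` with `.<region>` appended when a region is configured, and the SQL API lives under `.snowflakecomputing.com/api/v2/statements`. A sketch under those assumptions; both helper names are hypothetical.

```python
from typing import Optional


def account_identifier(account: str, region: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the assumed account identifier format."""
    identifier = f"https://{account}"
    if region:
        # Append the region only when one is configured
        identifier += f".{region}"
    return identifier


def statements_url(account: str, region: Optional[str] = None) -> str:
    """Hypothetical helper: base URL of the SQL API statements endpoint."""
    return f"{account_identifier(account, region)}.snowflakecomputing.com/api/v2/statements"


print(statements_url("xy12345", "us-east-1"))
# https://xy12345.us-east-1.snowflakecomputing.com/api/v2/statements
```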

#### Configuration Methods

```python { .api }
@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
    """
    Return connection widgets for Airflow UI form.

    Returns:
        Dictionary of form widgets
    """

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
    """
    Return custom field behaviour for connection form.

    Returns:
        Dictionary defining UI field behavior
    """
```
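
Airflow hooks return a dictionary from `get_ui_field_behaviour()` whose keys are `hidden_fields`, `relabeling`, and `placeholders`; the connection form uses them to hide, rename, and pre-fill fields. The values below are illustrative only and are not the Snowflake provider's actual settings, and `check_field_behaviour` is a hypothetical sanity check.

```python
from typing import Any

# Illustrative values only, not the provider's real form configuration
EXAMPLE_FIELD_BEHAVIOUR: dict[str, Any] = {
    "hidden_fields": ["port", "host"],
    "relabeling": {},
    "placeholders": {
        "login": "snowflake username",
        "schema": "snowflake schema",
    },
}

EXPECTED_KEYS = {"hidden_fields", "relabeling", "placeholders"}


def check_field_behaviour(behaviour: dict[str, Any]) -> list[str]:
    """Hypothetical check: return a list of problems with a field-behaviour dict."""
    problems = []
    unknown = set(behaviour) - EXPECTED_KEYS
    if unknown:
        problems.append(f"unknown keys: {sorted(unknown)}")
    if not isinstance(behaviour.get("hidden_fields", []), list):
        problems.append("hidden_fields must be a list")
    for key in ("relabeling", "placeholders"):
        if not isinstance(behaviour.get(key, {}), dict):
            problems.append(f"{key} must be a dict")
    # A placeholder on a hidden field would never be shown, so flag it
    overlap = set(behaviour.get("hidden_fields", [])) & set(behaviour.get("placeholders", {}))
    if overlap:
        problems.append(f"hidden fields with placeholders: {sorted(overlap)}")
    return problems


print(check_field_behaviour(EXAMPLE_FIELD_BEHAVIOUR))  # []
```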

### SQL API Hook

Hook for Snowflake's SQL API, a REST interface that accepts multiple SQL statements in a single request and runs them asynchronously. Requests are authenticated with either an OAuth token or a key-pair JWT.

```python { .api }
class SnowflakeSqlApiHook(SnowflakeHook):
    """
    Hook for interacting with Snowflake using SQL API.
    Enables submission of multiple SQL statements in a single request.
    """

    LIFETIME = timedelta(minutes=59)  # JWT Token lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # JWT Token renewal time

    def __init__(
        self,
        snowflake_conn_id: str,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        api_retry_args: dict[Any, Any] | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        """
        Initialize SQL API hook.

        Parameters:
        - snowflake_conn_id: Snowflake connection ID
        - token_life_time: JWT token lifetime
        - token_renewal_delta: JWT token renewal interval
        - api_retry_args: Retry configuration for API calls
        """
```
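
The class defaults issue a JWT valid for 59 minutes (`LIFETIME`) and renew it after 54 (`RENEWAL_DELTA`), leaving a five-minute margin before expiry. A sketch of that renew-before-expiry rule using the same defaults; `TokenCache` and its `issue` callback are hypothetical, and only the timing logic is shown (nothing is signed).

```python
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional

LIFETIME = timedelta(minutes=59)
RENEWAL_DELTA = timedelta(minutes=54)


class TokenCache:
    """Hypothetical cache that issues a new token once renewal_delta has elapsed."""

    def __init__(
        self,
        issue: Callable[[datetime, datetime], str],
        lifetime: timedelta = LIFETIME,
        renewal_delta: timedelta = RENEWAL_DELTA,
    ) -> None:
        if renewal_delta >= lifetime:
            # Renewing at or after expiry would hand out expired tokens
            raise ValueError("renewal_delta must be shorter than lifetime")
        self.issue = issue
        self.lifetime = lifetime
        self.renewal_delta = renewal_delta
        self.token: Optional[str] = None
        self.renew_at: Optional[datetime] = None

    def get(self, now: datetime) -> str:
        if self.token is None or now >= self.renew_at:
            # issue() receives the issued-at and expiry times of the new token
            self.token = self.issue(now, now + self.lifetime)
            self.renew_at = now + self.renewal_delta
        return self.token


issued = []
cache = TokenCache(lambda iat, exp: issued.append(exp) or f"token-{len(issued)}")
t0 = datetime(2024, 1, 1, tzinfo=timezone.utc)
first = cache.get(t0)
same = cache.get(t0 + timedelta(minutes=53))
renewed = cache.get(t0 + timedelta(minutes=54))
print(first, same, renewed)  # token-1 token-1 token-2
```

Rejecting a renewal delta that is not shorter than the lifetime keeps the cache from returning a token after its expiry.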

#### Token Management

```python { .api }
def get_private_key(self) -> None:
    """
    Load the private key from the Snowflake connection for JWT authentication.
    Nothing is returned; the key is kept on the hook for token generation.
    """

def get_headers(self) -> dict[str, Any]:
    """
    Form authentication headers based on an OAuth or JWT token.

    Returns:
        Dictionary containing HTTP headers for API requests
    """
```
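
For key-pair authentication, Snowflake expects a JWT whose issuer is `<ACCOUNT>.<USER>.SHA256:<fingerprint>` and whose subject is `<ACCOUNT>.<USER>`, where the fingerprint is the base64-encoded SHA-256 digest of the DER-encoded public key; requests then send `Authorization: Bearer <token>` together with `X-Snowflake-Authorization-Token-Type: KEYPAIR_JWT` (or `OAUTH` for OAuth tokens). The sketch only assembles the claims and headers: RS256 signing (for example with PyJWT) is omitted, the key bytes are placeholders, and `public_key_fingerprint`, `build_jwt_claims`, and `build_headers` are hypothetical helpers.

```python
import base64
import hashlib
import time
from typing import Any, Optional


def public_key_fingerprint(public_key_der: bytes) -> str:
    """SHA256:<base64 digest> of the DER-encoded public key."""
    digest = hashlib.sha256(public_key_der).digest()
    return "SHA256:" + base64.b64encode(digest).decode()


def build_jwt_claims(
    account: str,
    user: str,
    public_key_der: bytes,
    lifetime_s: int = 59 * 60,
    now: Optional[int] = None,
) -> dict[str, Any]:
    """Hypothetical helper: unsigned JWT claims for Snowflake key-pair auth."""
    # Assumption: drop any region suffix from an account locator, then upper-case
    qualified_user = f"{account.split('.')[0].upper()}.{user.upper()}"
    issued_at = int(time.time()) if now is None else now
    return {
        "iss": f"{qualified_user}.{public_key_fingerprint(public_key_der)}",
        "sub": qualified_user,
        "iat": issued_at,
        "exp": issued_at + lifetime_s,
    }


def build_headers(token: str, token_type: str = "KEYPAIR_JWT") -> dict[str, str]:
    """Hypothetical helper: SQL API request headers; token_type is KEYPAIR_JWT or OAUTH."""
    return {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {token}",
        "X-Snowflake-Authorization-Token-Type": token_type,
    }


claims = build_jwt_claims("xy12345.us-east-1", "etl_user", b"placeholder-der-bytes", now=1_700_000_000)
print(claims["sub"], claims["exp"] - claims["iat"])  # XY12345.ETL_USER 3540
```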

#### SQL API Operations

```python { .api }
def execute_query(
    self,
    sql: str,
    statement_count: int,
    query_tag: str = "",
    bindings: dict[str, Any] | None = None,
) -> list[str]:
    """
    Execute query using Snowflake SQL API.

    Parameters:
    - sql: SQL statements to execute
    - statement_count: Number of statements in SQL
    - query_tag: Optional query tag for tracking
    - bindings: Parameter bindings for queries

    Returns:
        List of query IDs for submitted statements
    """

def check_query_output(self, query_ids: list[str]) -> None:
    """
    Log query responses for given query IDs.

    Parameters:
    - query_ids: List of query IDs to check
    """
```
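
`execute_query` posts the SQL text to the SQL API's statements endpoint. A sketch of a request body of that kind, assuming the documented `statement`, `parameters`, and `bindings` fields; the `MULTI_STATEMENT_COUNT` parameter tells Snowflake how many statements to expect, which is why `statement_count` has to match the SQL. `build_statement_payload` is hypothetical, and the body the hook actually sends may contain further fields (for example result-format or timeout settings).

```python
import json
from typing import Any, Optional


def build_statement_payload(
    sql: str,
    statement_count: int,
    query_tag: str = "",
    bindings: Optional[dict[str, Any]] = None,
    **context: str,
) -> dict[str, Any]:
    """Hypothetical helper: request body for POST /api/v2/statements."""
    parameters: dict[str, Any] = {
        # Number of statements in the SQL text; a mismatch makes the request fail
        "MULTI_STATEMENT_COUNT": str(statement_count),
    }
    if query_tag:
        parameters["query_tag"] = query_tag
    body: dict[str, Any] = {"statement": sql, "parameters": parameters}
    # Optional session context such as database, schema, warehouse, role
    body.update({key: value for key, value in context.items() if value})
    if bindings:
        # Positional binds for '?' placeholders, e.g. {"1": {"type": "FIXED", "value": "10"}}
        body["bindings"] = bindings
    return body


payload = build_statement_payload(
    "CREATE TABLE t AS SELECT 1 AS x; UPDATE t SET x = 2;",
    statement_count=2,
    query_tag="monthly_summary",
    warehouse="COMPUTE_WH",
)
print(json.dumps(payload["parameters"]))
# {"MULTI_STATEMENT_COUNT": "2", "query_tag": "monthly_summary"}
```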

#### Query Status and Results

```python { .api }
def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
    """
    Get query status via SQL API.

    Parameters:
    - query_id: Query ID to check status

    Returns:
        Dictionary containing query status and metadata
    """

async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
    """
    Get query status asynchronously via SQL API (a coroutine; await it).

    Parameters:
    - query_id: Query ID to check status

    Returns:
        Dictionary containing query status and metadata
    """

def wait_for_query(
    self,
    query_id: str,
    raise_error: bool = False,
    poll_interval: int = 5,
    timeout: int = 60,
) -> dict[str, str | list[str]]:
    """
    Wait for query completion with polling.

    Parameters:
    - query_id: Query ID to wait for
    - raise_error: Raise exception on query failure
    - poll_interval: Polling interval in seconds
    - timeout: Maximum wait time in seconds

    Returns:
        Final query status dictionary
    """

def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
    """
    Get results from successful query execution.

    Parameters:
    - query_id: Query ID to retrieve results

    Returns:
        List of result rows as dictionaries
    """
```
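
`wait_for_query` follows a poll-until-done pattern, and `get_result_from_successful_sql_api_query` turns the API's row arrays into dictionaries. A sketch of both, assuming status values `running`, `success`, and `error`, and a result shape in which column names sit under `resultSetMetaData.rowType[*].name` and rows under `data` as lists of strings. `poll_until_done`, `rows_as_dicts`, and the fake status sequence are hypothetical.

```python
import time
from typing import Any, Callable


def poll_until_done(
    get_status: Callable[[], dict[str, Any]],
    poll_interval: float = 5,
    timeout: float = 60,
    raise_error: bool = False,
) -> dict[str, Any]:
    """Hypothetical poller: call get_status until it leaves the 'running' state."""
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status["status"] != "running":
            if raise_error and status["status"] == "error":
                raise RuntimeError(status.get("message", "query failed"))
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"query still running after {timeout} seconds")
        time.sleep(poll_interval)


def rows_as_dicts(response: dict[str, Any]) -> list[dict[str, Any]]:
    """Zip column names from resultSetMetaData.rowType with each data row."""
    names = [column["name"] for column in response["resultSetMetaData"]["rowType"]]
    return [dict(zip(names, row)) for row in response.get("data", [])]


statuses = iter([{"status": "running"}, {"status": "running"}, {"status": "success"}])
final = poll_until_done(lambda: next(statuses), poll_interval=0)
rows = rows_as_dicts({
    "resultSetMetaData": {"rowType": [{"name": "REGION"}, {"name": "TOTAL_SALES"}]},
    "data": [["WEST", "1100"], ["EAST", "900"]],
})
print(final["status"], rows[0])  # success {'REGION': 'WEST', 'TOTAL_SALES': '1100'}
```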

#### Request Handling

```python { .api }
def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
    """
    Build the request headers, query parameters, and URL for SQL API calls.

    Parameters:
    - query_id: Query ID for request

    Returns:
        Tuple of (headers, params, url)
    """
```
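
Each status or result call needs the same three pieces: authentication headers, query parameters carrying a fresh `requestId`, and the statement URL for the query ID. A sketch of assembling that `(headers, params, url)` tuple, assuming the URL format `https://<account>[.<region>].snowflakecomputing.com/api/v2/statements/<query_id>` and a UUID4 request ID; `request_parts` is hypothetical and takes the headers as an argument rather than generating a token.

```python
import uuid
from typing import Any, Optional


def request_parts(
    query_id: str,
    headers: dict[str, Any],
    account: str,
    region: Optional[str] = None,
) -> tuple[dict[str, Any], dict[str, Any], str]:
    """Hypothetical helper returning (headers, params, url) for one statement handle."""
    host = f"https://{account}" + (f".{region}" if region else "")
    # A unique requestId lets the API recognise a retried request
    params = {"requestId": str(uuid.uuid4())}
    url = f"{host}.snowflakecomputing.com/api/v2/statements/{query_id}"
    return headers, params, url


headers, params, url = request_parts(
    "01b2c3d4-0000-1111-0000-000000000000", {"Accept": "application/json"}, "xy12345"
)
print(url)
# https://xy12345.snowflakecomputing.com/api/v2/statements/01b2c3d4-0000-1111-0000-000000000000
```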

## Usage Examples

### Basic SQL Execution

```python
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook


def execute_sql_query():
    hook = SnowflakeHook(
        snowflake_conn_id='my_snowflake_conn',
        warehouse='COMPUTE_WH',
        database='ANALYTICS',
        schema='PUBLIC',
    )

    # Execute a single query; the handler pulls the scalar COUNT(*) value from the cursor
    result = hook.run(
        sql="SELECT COUNT(*) FROM users WHERE created_date >= '2024-01-01'",
        handler=lambda cursor: cursor.fetchone()[0],
    )

    return result


def execute_multiple_queries():
    hook = SnowflakeHook(snowflake_conn_id='my_snowflake_conn')

    # Execute multiple statements in order on one connection,
    # so the temporary table stays visible to the later statements
    hook.run([
        "CREATE TEMP TABLE temp_data AS SELECT * FROM raw_data",
        "UPDATE temp_data SET status = 'processed' WHERE id > 1000",
        "INSERT INTO processed_data SELECT * FROM temp_data",
    ])
```
435
436
### SQL API Usage
437
438
```python
439
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
440
441
def execute_with_api():
442
hook = SnowflakeSqlApiHook(
443
snowflake_conn_id='my_snowflake_conn',
444
token_life_time=timedelta(minutes=30)
445
)
446
447
# Submit multiple statements via API
448
query_ids = hook.execute_query(
449
sql="""
450
CREATE TABLE IF NOT EXISTS sales_summary AS
451
SELECT region, SUM(amount) as total_sales
452
FROM sales
453
GROUP BY region;
454
455
UPDATE sales_summary
456
SET total_sales = total_sales * 1.1
457
WHERE region = 'WEST';
458
""",
459
statement_count=2,
460
query_tag="monthly_summary"
461
)
462
463
# Wait for completion
464
for query_id in query_ids:
465
status = hook.wait_for_query(query_id, raise_error=True)
466
print(f"Query {query_id} completed with status: {status['status']}")
467
```

### Snowpark Integration

```python
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from snowflake.snowpark.functions import col, sum as sum_


def snowpark_processing():
    hook = SnowflakeHook(snowflake_conn_id='my_snowflake_conn')

    # Get a Snowpark session (requires the snowflake-snowpark-python package)
    session = hook.get_snowpark_session()

    # Use the Snowpark DataFrame API
    df = session.table("raw_sales_data")

    # Filter, aggregate, and name the aggregate column explicitly with alias()
    # rather than renaming the generated "SUM(AMOUNT)" column afterwards
    processed_df = (
        df.filter(col("amount") > 100)
        .group_by("region")
        .agg(sum_(col("amount")).alias("total_sales"))
    )

    # Save results, replacing the table if it already exists
    processed_df.write.save_as_table("regional_sales", mode="overwrite")

    return processed_df.count()
```

## Error Handling

Errors can originate at several layers:

- **Connection Errors**: Authentication failures, network timeouts, invalid credentials
- **SQL Errors**: Syntax errors, permission issues, resource constraints
- **API Errors**: Invalid query IDs, malformed requests, rate limiting
- **Token Errors**: JWT expiration, key validation failures

Exceptions raised by the Snowflake Python connector generally carry a Snowflake error code, an SQL state, and, where available, the query ID alongside a descriptive message; these are the details to collect when troubleshooting.
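
Rate limiting and other transient API failures are commonly handled by retrying with exponential backoff, and the SQL API hook's `api_retry_args` parameter exists to configure retries of this kind. A stdlib sketch of the idea, assuming HTTP 429, 503, and 504 are the statuses worth retrying; `call_with_retries` and `TransientApiError` are hypothetical and do not describe the hook's actual retry policy.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")
RETRYABLE_STATUSES = {429, 503, 504}  # assumption: statuses treated as transient


class TransientApiError(Exception):
    """Hypothetical error carrying the HTTP status of a failed API call."""

    def __init__(self, status: int) -> None:
        super().__init__(f"HTTP {status}")
        self.status = status


def call_with_retries(
    func: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Retry func on retryable statuses, doubling the delay after each failure."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return func()
        except TransientApiError as exc:
            # Give up on non-retryable statuses or when attempts are exhausted
            if exc.status not in RETRYABLE_STATUSES or attempt == max_attempts:
                raise
            sleep(min(max_delay, base_delay * 2 ** (attempt - 1)))
    raise RuntimeError("unreachable")


responses = iter([TransientApiError(429), TransientApiError(503), "ok"])


def flaky() -> str:
    item = next(responses)
    if isinstance(item, Exception):
        raise item
    return item


delays: list[float] = []
result = call_with_retries(flaky, sleep=delays.append)
print(result, delays)  # ok [1.0, 2.0]
```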