# Utility Functions

Helper functions for quoting SQL string literals, generating key-pair JWTs for the Snowflake SQL API, emitting OpenLineage events, injecting Snowpark sessions, and handling Airflow version differences. They support the provider's Snowflake hooks and operators in Airflow workflows.

## Capabilities

### Parameter Safety

`enclose_param()` prepares a value for use as a Snowflake single-quoted string literal by doubling every single quote and wrapping the result in single quotes. It does not escape backslashes, which Snowflake also treats as escape characters in single-quoted strings, so it is not a complete defense against SQL injection. Use bind parameters for untrusted input.

```python { .api }
def enclose_param(param: str) -> str:
    """
    Replace all single quotes in the parameter with two single quotes and
    enclose the parameter in single quotes.
    Produces a Snowflake single-quoted string constant; backslashes are
    left unchanged.

    Parameters:
    - param: String to enclose

    Returns:
    The quote-escaped value wrapped in single quotes

    Examples:
    - enclose_param("without quotes") returns "'without quotes'"
    - enclose_param("'with quotes'") returns "'''with quotes'''"
    - enclose_param("Today's sales") returns "'Today''s sales'"
    """
```
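The transformation is small enough to restate directly. The function below is a standalone sketch under a hypothetical name, `enclose_param_sketch`, so it does not shadow the provider's function; it reproduces the documented behavior and is checked against the docstring examples.

```python
def enclose_param_sketch(param: str) -> str:
    """Double every single quote, then wrap the value in single quotes."""
    return "'" + param.replace("'", "''") + "'"


# The docstring examples
for raw in ["without quotes", "'with quotes'", "Today's sales"]:
    print(enclose_param_sketch(raw))
# 'without quotes'
# '''with quotes'''
# 'Today''s sales'
```

Only single quotes are touched; every other character, including a backslash, passes through unchanged.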
### Snowpark Session Management

`inject_session_into_op_kwargs()` adds a Snowpark session to a callable's keyword arguments only when the callable declares a parameter named `session`.

```python { .api }
def inject_session_into_op_kwargs(
    python_callable: Callable,
    op_kwargs: dict,
    session: Session | None,
) -> dict:
    """
    Inject a Snowpark session into operator kwargs based on the callable's signature.

    Parameters:
    - python_callable: Python callable whose signature is inspected
    - op_kwargs: Current operator keyword arguments
    - session: Snowpark Session object to inject

    Returns:
    A kwargs dictionary that includes "session" if the callable accepts it

    Behavior:
    - Inspects the callable signature for a parameter named "session"
      (a catch-all **kwargs does not count)
    - If found, returns a new dict with "session" added; an existing
      "session" entry is overwritten
    - Otherwise returns op_kwargs unchanged; the input dict is not mutated
    """
```
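A standalone sketch of the signature check, assuming the utility looks for an explicitly named `session` parameter with `inspect.signature`, as the behavior notes describe. `inject_session_sketch`, `wants_session`, and `catch_all` are hypothetical names for illustration.

```python
import inspect
from typing import Any, Callable


def inject_session_sketch(python_callable: Callable[..., Any], op_kwargs: dict, session: Any) -> dict:
    """Return kwargs with `session` added only if the callable declares it."""
    if "session" in inspect.signature(python_callable).parameters:
        # Build a new dict so the caller's op_kwargs is left untouched
        return {**op_kwargs, "session": session}
    return op_kwargs


def wants_session(table: str, session):
    return f"{table} via {session}"


def catch_all(table: str, **kwargs):
    return f"{table} with {sorted(kwargs)}"


kwargs = {"table": "raw_data"}
print(wants_session(**inject_session_sketch(wants_session, kwargs, "SESSION")))  # raw_data via SESSION
print(catch_all(**inject_session_sketch(catch_all, kwargs, "SESSION")))  # raw_data with []
print(kwargs)  # {'table': 'raw_data'}
```

A function that only accepts `**kwargs` never receives the session, so declare `session` explicitly when you need it.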
### JWT Token Generation

`JWTGenerator` creates and signs key-pair JWTs for Snowflake SQL API authentication. It caches the signed token and regenerates it only after the renewal delay has elapsed.

```python { .api }
class JWTGenerator:
    """
    Creates and signs JWTs for Snowflake SQL API authentication.
    Caches the current token and renews it after renewal_delay.
    """

    LIFETIME = timedelta(minutes=59)  # Default token lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Default renewal delay (5 minutes before expiry)
    ALGORITHM = "RS256"  # JWT signing algorithm (RSA with SHA-256)

    def __init__(
        self,
        account: str,
        user: str,
        private_key: Any,
        lifetime: timedelta = LIFETIME,
        renewal_delay: timedelta = RENEWAL_DELTA,
    ):
        """
        Initialize a JWT generator for Snowflake authentication.

        Parameters:
        - account: Snowflake account identifier; region and cloud suffixes
          are stripped for the JWT claims
        - user: Snowflake username
        - private_key: Loaded RSA private key object (for example, from
          cryptography's load_pem_private_key), not raw PEM bytes
        - lifetime: How long each token is valid; Snowflake honors
          key-pair JWTs for at most one hour
        - renewal_delay: Time after a token is issued at which it is
          renewed; should be shorter than lifetime
        """
```
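To make the renewal timing concrete, the sketch below models the caching behavior described above with a fake signer and an injectable clock. `CachedTokenSketch`, `sign`, and `clock` are hypothetical names, and the constructor's validation checks are additions of this sketch, not necessarily part of the provider.

```python
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional


class CachedTokenSketch:
    """Hypothetical stand-in for JWTGenerator's caching; `sign` replaces real JWT signing."""

    def __init__(
        self,
        sign: Callable[[datetime], str],
        lifetime: timedelta = timedelta(minutes=59),
        renewal_delay: timedelta = timedelta(minutes=54),
        clock: Callable[[], datetime] = lambda: datetime.now(timezone.utc),
    ):
        # Checks added by this sketch: stay within one hour and renew before expiry
        if lifetime > timedelta(hours=1):
            raise ValueError("lifetime exceeds one hour")
        if renewal_delay >= lifetime:
            raise ValueError("renewal_delay must be shorter than lifetime")
        self._sign = sign
        self._clock = clock
        self.lifetime = lifetime
        self.renewal_delay = renewal_delay
        self.token: Optional[str] = None
        self.renew_time = clock()

    def get_token(self) -> str:
        now = self._clock()
        if self.token is None or self.renew_time <= now:
            # Schedule the next renewal relative to this issue time
            self.renew_time = now + self.renewal_delay
            self.token = self._sign(now)
        return self.token


start = datetime(2024, 1, 1, tzinfo=timezone.utc)
current = [start]
gen = CachedTokenSketch(lambda issued: f"jwt@{issued:%H:%M}", clock=lambda: current[0])
print(gen.get_token())  # jwt@00:00
current[0] = start + timedelta(minutes=53)
print(gen.get_token())  # jwt@00:00 (cached)
current[0] = start + timedelta(minutes=54)
print(gen.get_token())  # jwt@00:54 (renewed)
```

With the defaults, a token issued at 00:00 is reused until 00:54 and re-signed on the first call after that, five minutes before it would expire.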
#### Token Management Methods

```python { .api }
def get_token(self) -> str | None:
    """
    Return the cached JWT, or sign a new one when none exists or
    renewal_delay has elapsed since the last one was issued.

    Returns:
    The signed JWT string. The annotation allows None because the cache
    starts empty; errors raised while signing propagate to the caller.

    Behavior:
    - Returns the cached token until the renewal time is reached
    - Otherwise signs a new token with iat=now and exp=now + lifetime,
      and schedules the next renewal at now + renewal_delay
    """

def prepare_account_name_for_jwt(self, raw_account: str) -> str:
    """
    Prepare the account identifier for the JWT claims.
    Snowflake expects the uppercase account identifier without subdomain,
    region, or cloud information.

    Parameters:
    - raw_account: Account identifier as configured

    Returns:
    Uppercase account identifier for the JWT claims
    """

def calculate_public_key_fingerprint(self, private_key: Any) -> str:
    """
    Calculate the public key fingerprint from the private key.
    Snowflake uses it to identify which registered key signed the JWT.

    Parameters:
    - private_key: Loaded RSA private key object

    Returns:
    "SHA256:" followed by the base64-encoded SHA-256 digest of the
    DER-encoded public key
    """
```
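Both helper steps can be reproduced with the standard library. The sketch assumes the trimming rule from Snowflake's published key-pair JWT sample (keep the part before the first dot, or before the first hyphen for `.global` replication identifiers, then upper-case), which this class appears to follow. With a real key, the DER bytes would come from cryptography's `private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)`; the empty byte string below only demonstrates the output format. Both function names are hypothetical.

```python
import base64
import hashlib


def prepare_account_name_sketch(raw_account: str) -> str:
    """Trim region/cloud suffixes and upper-case the account identifier."""
    account = raw_account
    if ".global" not in account:
        # General case: keep everything before the first dot
        idx = account.find(".")
        if idx > 0:
            account = account[:idx]
    else:
        # Replication (.global) case: keep everything before the first hyphen
        idx = account.find("-")
        if idx > 0:
            account = account[:idx]
    return account.upper()


def fingerprint_from_der(public_key_der: bytes) -> str:
    """'SHA256:' plus base64 of the SHA-256 digest of a DER public key."""
    digest = hashlib.sha256(public_key_der).digest()
    return "SHA256:" + base64.b64encode(digest).decode("ascii")


print(prepare_account_name_sketch("mycompany.us-east-1"))  # MYCOMPANY
print(prepare_account_name_sketch("myorg-myaccount"))  # MYORG-MYACCOUNT
print(fingerprint_from_der(b""))  # SHA256:47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=
```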
#### JWT Constants

Names of the registered JWT claims used in the token payload:

```python { .api }
# JWT claim field names
ISSUER = "iss"  # Issuer: <ACCOUNT>.<USER>.SHA256:<public key fingerprint>
EXPIRE_TIME = "exp"  # Expiration time
ISSUE_TIME = "iat"  # Issue time
SUBJECT = "sub"  # Subject: <ACCOUNT>.<USER>
```
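A minimal sketch of how the claims fit together, following the claim layout in Snowflake's key-pair authentication documentation. `build_claims` is a hypothetical helper and `SHA256:abc=` is a placeholder fingerprint. Signing is left out; with PyJWT, `jwt.encode(claims, key=private_key, algorithm="RS256")` converts the `datetime` values of `iat` and `exp` to numeric timestamps.

```python
from datetime import datetime, timedelta, timezone

ISSUER = "iss"
EXPIRE_TIME = "exp"
ISSUE_TIME = "iat"
SUBJECT = "sub"


def build_claims(account: str, user: str, fingerprint: str,
                 issued_at: datetime, lifetime: timedelta) -> dict:
    """Assemble the key-pair JWT payload; account is already trimmed and upper-cased."""
    qualified_username = f"{account}.{user.upper()}"
    return {
        ISSUER: f"{qualified_username}.{fingerprint}",  # ACCOUNT.USER.SHA256:...
        SUBJECT: qualified_username,  # ACCOUNT.USER
        ISSUE_TIME: issued_at,
        EXPIRE_TIME: issued_at + lifetime,
    }


claims = build_claims("MYCOMPANY", "airflow_service", "SHA256:abc=",
                      datetime(2024, 1, 1, tzinfo=timezone.utc), timedelta(minutes=59))
print(claims[ISSUER])  # MYCOMPANY.AIRFLOW_SERVICE.SHA256:abc=
print(claims[SUBJECT])  # MYCOMPANY.AIRFLOW_SERVICE
```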
### OpenLineage Integration

OpenLineage helpers for Snowflake: they normalize account identifiers and connection URIs into the form OpenLineage expects and emit lineage events for executed queries.

#### Account and URI Utilities

```python { .api }
def fix_account_name(name: str) -> str:
    """
    Normalize an account identifier to the OpenLineage format
    <account_id>.<region>.<cloud>.

    Parameters:
    - name: Raw Snowflake account identifier

    Returns:
    Account name formatted for OpenLineage
    """

def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
    """
    Normalize a Snowflake SQLAlchemy connection URI to the OpenLineage structure.
    Standardizes the connection URI format for lineage tracking.

    Parameters:
    - uri: Raw SQLAlchemy connection URI

    Returns:
    Normalized URI string for OpenLineage
    """
```
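A sketch of the URI normalization, under the assumption that the function keeps only the scheme, the normalized host, and the path, so credentials and query parameters never reach lineage metadata. `normalize_uri_sketch` is hypothetical, and the `fix_host` callable stands in for the account-name normalization that `fix_account_name` performs; the example passes the identity function rather than guessing at those rules.

```python
from typing import Callable
from urllib.parse import urlparse, urlunparse


def normalize_uri_sketch(uri: str, fix_host: Callable[[str], str]) -> str:
    """Rebuild a connection URI from scheme, normalized host, and path only."""
    parts = urlparse(uri)
    hostname = parts.hostname  # excludes user:password and port; lower-cased
    if not hostname:
        return uri
    # Credentials, query string (warehouse, role) and fragment are dropped
    return urlunparse((parts.scheme, fix_host(hostname), parts.path, parts.params, "", ""))


uri = "snowflake://user:secret@xy12345.us-east-2.aws/analytics/public?warehouse=WH&role=R"
print(normalize_uri_sketch(uri, fix_host=lambda host: host))
# snowflake://xy12345.us-east-2.aws/analytics/public
```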
#### Event Generation

```python { .api }
def emit_openlineage_events_for_snowflake_queries(
    hook: SnowflakeHook | SnowflakeSqlApiHook,
    task_instance: TaskInstance,
    query_ids: list[str],
    sql: str,
) -> None:
    """
    Emit OpenLineage events for executed Snowflake queries.

    Parameters:
    - hook: Snowflake hook instance (SnowflakeHook or SnowflakeSqlApiHook)
    - task_instance: Airflow TaskInstance providing the run context
    - query_ids: Snowflake query IDs to report
    - sql: SQL statements that were executed

    Behavior:
    - Extracts table and column lineage for the executed queries
    - Emits start and complete OpenLineage events
    - Attaches metadata about the data transformations
    """
```
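The start/complete pairing can be pictured with plain dictionaries shaped like OpenLineage `RunEvent`s. This sketch assumes one run per query ID and is not the provider's actual output: `query_run_events` and the `<parent_job>.query.<n>` job naming are hypothetical, and real events carry more facets (plus the `_producer` and `_schemaURL` fields every facet requires), which are omitted. `externalQuery` is the OpenLineage run facet for referencing a query ID in an external system.

```python
import uuid
from datetime import datetime, timezone


def query_run_events(query_ids: list[str], namespace: str, parent_job: str,
                     producer: str) -> list[dict]:
    """Build a START and a COMPLETE RunEvent-shaped dict for each query ID."""
    events = []
    for position, query_id in enumerate(query_ids, start=1):
        run = {
            "runId": str(uuid.uuid4()),  # shared by the START and COMPLETE events
            "facets": {
                # externalQuery links the run to the Snowflake query ID
                "externalQuery": {"externalQueryId": query_id, "source": namespace},
            },
        }
        # Hypothetical job naming: one child job per query of the task
        job = {"namespace": namespace, "name": f"{parent_job}.query.{position}"}
        for event_type in ("START", "COMPLETE"):
            events.append({
                "eventType": event_type,
                "eventTime": datetime.now(timezone.utc).isoformat(),
                "run": run,
                "job": job,
                "producer": producer,
            })
    return events


events = query_run_events(["01b2c3d4-0000-0001-0000-000000000001"],
                          namespace="snowflake://xy12345.us-east-2.aws",
                          parent_job="analytics_dag.load_task",
                          producer="https://example.com/hypothetical-producer")
for event in events:
    print(event["eventType"], event["job"]["name"])
# START analytics_dag.load_task.query.1
# COMPLETE analytics_dag.load_task.query.1
```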
### Version Compatibility

Helpers that let the same provider code run on Airflow 2.x and 3.x.

```python { .api }
# Version detection constant
AIRFLOW_V_3_0_PLUS: bool  # True when running Airflow 3.0 or later

# Version-appropriate BaseOperator
BaseOperator  # Imported from airflow.sdk on 3.0+, from airflow.models before 3.0

def get_base_airflow_version_tuple() -> tuple[int, int, int]:
    """
    Return the installed Airflow base version as (major, minor, micro).
    Used for version compatibility checks.

    Returns:
    Tuple of (major, minor, micro) version numbers

    Examples:
    - For Airflow 2.10.1 returns (2, 10, 1)
    - For Airflow 3.0.0 returns (3, 0, 0)
    """
```
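A standard-library sketch of the parsing and the 3.0 check. The provider most likely reads `airflow.__version__` through `packaging.version.Version`; the regular expression here is a simplified stand-in (the hypothetical `base_version_tuple`) that keeps only the leading numeric components. Comparing base tuples means a release candidate such as `3.0.0rc1` already counts as 3.0 or later, whereas a full `Version` comparison would order it before `3.0.0`.

```python
import re
from typing import Tuple


def base_version_tuple(version: str) -> Tuple[int, int, int]:
    """Parse 'major.minor[.micro]' from the start of a version string, ignoring suffixes."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version: {version!r}")
    major, minor, micro = match.groups(default="0")
    return int(major), int(minor), int(micro)


for v in ["2.10.1", "3.0.0", "3.0.0rc1", "2.9.3+astro.1"]:
    tup = base_version_tuple(v)
    print(v, tup, tup >= (3, 0, 0))
# 2.10.1 (2, 10, 1) False
# 3.0.0 (3, 0, 0) True
# 3.0.0rc1 (3, 0, 0) True
# 2.9.3+astro.1 (2, 9, 3) False
```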
## Usage Examples

### Safe Parameter Handling

```python
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.utils.common import enclose_param


def safe_query_execution(customer_name: str, region: str):
    """Run a filtered query, sending user-provided values as bind parameters."""
    hook = SnowflakeHook(snowflake_conn_id="snowflake_prod")

    # Preferred for untrusted input: let the connector bind the values.
    # snowflake-connector-python defaults to the "pyformat" paramstyle.
    sql = """
        SELECT *
        FROM customers
        WHERE customer_name = %(customer_name)s
          AND region = %(region)s
    """
    return hook.run(
        sql,
        parameters={"customer_name": customer_name, "region": region},
        handler=lambda cursor: cursor.fetchall(),
    )


def region_filter(region: str) -> str:
    """Inline a trusted literal where binding is not an option."""
    # enclose_param doubles single quotes but leaves backslashes alone
    if "\\" in region:
        raise ValueError("backslashes are not allowed in inlined literals")
    return f"region = {enclose_param(region)}"


# Example usage
customers = safe_query_execution("O'Reilly Industries", "North America")
print(region_filter("North America"))  # region = 'North America'
```
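Why quote-doubling alone is not enough for untrusted input can be seen with a small model of how a lexer reads a single-quoted literal. The sketch assumes Snowflake-style escaping, where both a doubled quote and a backslash escape (Snowflake documents backslash escape sequences for single-quoted string constants) are accepted. `enclose_param_sketch` reproduces the documented behavior of `enclose_param()`, and `literal_end` is a hypothetical, simplified tokenizer, not Snowflake's parser.

```python
def enclose_param_sketch(param: str) -> str:
    # Same transformation as enclose_param(): double quotes, then wrap
    return "'" + param.replace("'", "''") + "'"


def literal_end(sql: str, start: int = 0) -> int:
    """Return the index just past the single-quoted literal opening at sql[start].

    Simplified model (assumption): '' stands for one quote, and a backslash
    escapes whatever character follows it.
    """
    if sql[start] != "'":
        raise ValueError("expected an opening single quote")
    i = start + 1
    while i < len(sql):
        ch = sql[i]
        if ch == "\\":
            i += 2  # backslash escape consumes the next character
        elif ch == "'":
            if i + 1 < len(sql) and sql[i + 1] == "'":
                i += 2  # doubled quote is a literal quote
            else:
                return i + 1  # a lone quote closes the literal
        else:
            i += 1
    raise ValueError("unterminated string literal")


benign = enclose_param_sketch("O'Reilly")
hostile = enclose_param_sketch("\\' OR 1=1 --")

# The benign literal spans the whole enclosed string
print(literal_end(benign) == len(benign))  # True
# The hostile literal closes early, leaving SQL text outside the string
print(repr(hostile[literal_end(hostile):]))  # " OR 1=1 --'"
```

Bind parameters avoid the problem entirely because the value never becomes part of the SQL text that the lexer reads.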
### Snowpark Session Injection

```python
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs


def my_data_processing_function(data_param: str, session=None, **kwargs):
    """
    Count rows in a category. Because the signature declares `session`,
    a Snowpark session is injected when one is available.
    """
    if session is not None:
        # Use the Snowpark DataFrame API
        df = session.table("raw_data")
        return df.filter(df.col("category") == data_param).count()
    # Fallback when no session was provided
    print(f"Processing {data_param} without Snowpark")
    return 0


def run_with_snowpark_session(python_callable, **op_kwargs):
    """Call a function, injecting a Snowpark session only if it accepts one."""
    # SnowparkOperator performs this injection step internally
    session = SnowflakeHook(snowflake_conn_id="snowflake_prod").get_snowpark_session()
    try:
        enhanced_kwargs = inject_session_into_op_kwargs(
            python_callable=python_callable,
            op_kwargs=op_kwargs,
            session=session,
        )
        return python_callable(**enhanced_kwargs)
    finally:
        session.close()
```
### JWT Token Management

```python
import os
from datetime import timedelta

from cryptography.hazmat.primitives import serialization

from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator


def build_jwt_generator() -> JWTGenerator:
    """Set up JWT authentication for the Snowflake SQL API."""
    # Read the private key from a path supplied by the environment
    private_key_path = os.environ["SNOWFLAKE_PRIVATE_KEY_PATH"]
    with open(private_key_path, "rb") as key_file:
        # JWTGenerator needs a loaded key object, not raw PEM bytes;
        # pass the passphrase as bytes if the key is encrypted
        private_key = serialization.load_pem_private_key(key_file.read(), password=None)

    return JWTGenerator(
        account="mycompany.us-east-1",  # trimmed to MYCOMPANY for the JWT claims
        user="airflow_service",
        private_key=private_key,
        lifetime=timedelta(hours=1),  # Snowflake's maximum key-pair JWT lifetime
        renewal_delay=timedelta(minutes=50),  # renew 10 minutes before expiry
    )


# Create the generator once so its cached token is reused across calls
jwt_gen = build_jwt_generator()


def build_auth_headers() -> dict[str, str]:
    """Headers for a Snowflake SQL API call authenticated with a key-pair JWT."""
    token = jwt_gen.get_token()  # cached; re-signed once renewal_delay has passed
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": "application/json",
        "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
    }
```
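For reference, the request these headers accompany can be assembled with the standard library. The sketch only builds a `urllib.request.Request` for the SQL API v2 `statements` endpoint and does not send it; `build_statement_request` is a hypothetical helper, and in practice `SnowflakeSqlApiHook` issues these requests for you. The URL uses the full account identifier, while the JWT claims use the trimmed, upper-cased form.

```python
import json
import urllib.request
from typing import Optional


def build_statement_request(account_identifier: str, token: str, statement: str,
                            warehouse: Optional[str] = None,
                            timeout_s: int = 60) -> urllib.request.Request:
    """Build (but do not send) a Snowflake SQL API v2 statement request."""
    body: dict = {"statement": statement, "timeout": timeout_s}
    if warehouse:
        body["warehouse"] = warehouse
    return urllib.request.Request(
        url=f"https://{account_identifier}.snowflakecomputing.com/api/v2/statements",
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
        },
        method="POST",
    )


req = build_statement_request("xy12345.us-east-2.aws", "<jwt>", "SELECT 1", warehouse="ANALYTICS_WH")
print(req.full_url)  # https://xy12345.us-east-2.aws.snowflakecomputing.com/api/v2/statements
```

Sending it would be `urllib.request.urlopen(req)`. A 200 response carries the result, while 202 means the statement is still running and its handle must be polled.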
### OpenLineage Event Generation

```python
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
from airflow.providers.snowflake.utils.openlineage import (
    emit_openlineage_events_for_snowflake_queries,
)


def execute_with_lineage_tracking(task_instance, sql_statements: str, statement_count: int):
    """Execute SQL through the SQL API and emit OpenLineage events for it."""
    hook = SnowflakeSqlApiHook(snowflake_conn_id="snowflake_prod")

    # Submit the statements and collect their query IDs
    query_ids = hook.execute_query(
        sql=sql_statements,
        statement_count=statement_count,
        query_tag="lineage_tracked_operation",
    )

    # Wait for every query to finish, raising if any fails
    for query_id in query_ids:
        hook.wait_for_query(query_id, raise_error=True)

    # Emit OpenLineage events for the completed queries
    emit_openlineage_events_for_snowflake_queries(
        hook=hook,
        task_instance=task_instance,
        query_ids=query_ids,
        sql=sql_statements,
    )
    return query_ids


class LineageAwareSnowflakeOperator(SnowflakeSqlApiOperator):
    """SnowflakeSqlApiOperator that emits OpenLineage events after execution."""

    def execute(self, context):
        # Run the queries. With deferrable=True the task defers here,
        # so the code below is not reached in this method.
        result = super().execute(context)

        # Assumption: the operator stores submitted query IDs in `query_ids`;
        # check the attribute name for your provider version.
        query_ids = getattr(self, "query_ids", None)
        if query_ids:
            emit_openlineage_events_for_snowflake_queries(
                hook=SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id),
                task_instance=context["task_instance"],
                query_ids=query_ids,
                sql=self.sql,
            )
        return result
```
### Version Compatibility Handling

```python
from airflow.providers.snowflake.version_compat import (
    AIRFLOW_V_3_0_PLUS,
    BaseOperator,
    get_base_airflow_version_tuple,
)


class CompatibleSnowflakeOperator(BaseOperator):
    """
    Runs on Airflow 2.x and 3.x. BaseOperator already comes from airflow.sdk
    on 3.0+ and from airflow.models before 3.0, so no per-version subclass is needed.
    """

    def execute(self, context):
        version = ".".join(map(str, get_base_airflow_version_tuple()))
        if AIRFLOW_V_3_0_PLUS:
            # Branch only where behavior genuinely differs on Airflow 3.x
            self.log.info("Running on Airflow %s with Airflow 3 behavior", version)
        else:
            self.log.info("Running on Airflow %s with Airflow 2 behavior", version)
```
### Combined Utility Usage

```python
from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator
from airflow.providers.snowflake.utils.common import enclose_param


def build_secure_query(region: str, category: str) -> str:
    """Build a two-statement SQL script with safely quoted literals."""
    # region and category come from the fixed lists below, not from users
    safe_region = enclose_param(region)
    safe_category = enclose_param(category)

    # {{{{ ds }}}} in the f-string renders as {{ ds }}, which Airflow templates.
    # TEMPORARY keeps each task's intermediate table private to its session,
    # so parallel tasks do not overwrite each other's results.
    return f"""
        CREATE OR REPLACE TEMPORARY TABLE analytics.regional_analysis AS
        SELECT
            region,
            product_category,
            COUNT(*) AS transaction_count,
            SUM(amount) AS total_revenue,
            AVG(amount) AS avg_transaction
        FROM raw.transactions
        WHERE region = {safe_region}
          AND product_category = {safe_category}
          AND transaction_date = '{{{{ ds }}}}'
        GROUP BY region, product_category;

        -- Append summary statistics
        INSERT INTO analytics.daily_summary
        SELECT
            '{{{{ ds }}}}' AS summary_date,
            {safe_region} AS region,
            {safe_category} AS category,
            SUM(transaction_count) AS total_transactions,
            SUM(total_revenue) AS total_revenue
        FROM analytics.regional_analysis;
    """


def create_secure_analytics_pipeline() -> DAG:
    """Create a daily analytics DAG with one SQL API task per region and category."""
    with DAG(
        "secure_analytics_pipeline",
        start_date=datetime(2024, 1, 1),
        schedule="@daily",  # schedule_interval is deprecated and removed in Airflow 3
        catchup=False,
    ) as dag:
        regions = ["North", "South", "East", "West"]
        categories = ["Electronics", "Clothing", "Home & Garden"]

        for region in regions:
            for category in categories:
                task_id = f"analyze_{region}_{category}".lower().replace(" ", "_").replace("&", "and")
                SnowflakeSqlApiOperator(
                    task_id=task_id,
                    snowflake_conn_id="snowflake_prod",
                    sql=build_secure_query(region, category),
                    statement_count=2,
                    deferrable=True,
                    poll_interval=20,
                    warehouse="ANALYTICS_WH",
                    # JWT settings: Snowflake honors key-pair JWTs for at most one hour
                    token_life_time=timedelta(minutes=59),
                    token_renewal_delta=timedelta(minutes=54),
                    session_parameters={
                        "QUERY_TAG": f"secure_analytics_{region}_{category}",
                        "AUTOCOMMIT": True,
                    },
                )
    return dag


# Expose the DAG object at module level for discovery
dag = create_secure_analytics_pipeline()
```
## Security Best Practices

### Parameter Sanitization
- Pass untrusted values as bind parameters (for example, `parameters=` on `SnowflakeHook.run()`) rather than inlining them into SQL
- Use `enclose_param()` only for trusted or validated values; it doubles single quotes but does not escape backslashes
- Validate input parameters (type, length, allowed characters) before building SQL

### Token Management
- Keep private keys in an Airflow secrets backend, a secret manager, or a file with restricted permissions, never in DAG code
- Keep token lifetimes at or below one hour, the maximum Snowflake honors for key-pair JWTs, and rely on automatic renewal for longer operations
- Monitor token usage and renewal patterns

### Connection Security
- Use role-based access control in Snowflake
- Validate connection parameters before use
- Monitor authentication failures and suspicious activity

## Performance Optimization

### JWT Token Caching
- Tokens are cached and reused until `renewal_delay` has elapsed since they were issued, then re-signed on the next `get_token()` call
- Reuse one `JWTGenerator` instance; creating a new one per request discards the cache
- Monitor token generation overhead in high-frequency scenarios

### Parameter Processing
- `enclose_param()` performs a single string replacement, so its cost is negligible compared with query execution
- For many values, batch them into fewer statements (for example, with a cursor's `executemany()`) rather than issuing one query per value

## Error Handling

Most of these utilities do not validate their inputs up front, so failures typically surface as exceptions from the underlying libraries. Common failure modes:

- **Parameter errors**: `enclose_param()` expects a `str`; passing `None` raises `AttributeError`
- **Authentication errors**: malformed or mismatched keys (such as raw PEM bytes passed where a loaded key object is expected), signing failures, and expired or over-long tokens
- **OpenLineage errors**: metadata extraction failures and event emission issues
- **Version compatibility**: unsupported features, import failures, and version detection issues

Check the exception type and message from the underlying library to identify the cause.