# Spark Integration

Integration with Athena's Spark execution engine for distributed processing, Jupyter notebook compatibility, and advanced analytics workloads. Enables running Spark code directly on Athena's serverless Spark engine.

## Capabilities

### Spark Cursor

Cursor for executing Spark calculations on Athena's Spark engine, providing distributed processing capabilities and integration with Jupyter notebooks.

```python { .api }
class SparkCursor:
    session_id: str
    calculation_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    state: Optional[str]
    state_change_reason: Optional[str]
    submission_date_time: Optional[datetime]
    completion_date_time: Optional[datetime]
    dpu_execution_in_millis: Optional[int]
    progress: Optional[str]
    std_out_s3_uri: Optional[str]
    std_error_s3_uri: Optional[str]
    result_s3_uri: Optional[str]
    result_type: Optional[str]

    def execute(self, code: str, **kwargs) -> SparkCursor:
        """
        Execute Spark code on Athena's Spark engine.

        Parameters:
        - code: Spark code to execute (Python, Scala, or SQL)
        - **kwargs: Additional execution options

        Returns:
            Self for method chaining
        """

    def cancel(self) -> None:
        """Cancel the currently executing Spark calculation."""

    def close(self) -> None:
        """Close the cursor and clean up resources."""

    @property
    def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
        """Get the current calculation execution metadata."""
```

### Async Spark Cursor

Asynchronous version of SparkCursor for non-blocking Spark calculations.

```python { .api }
class AsyncSparkCursor:
    def execute(self, code: str, **kwargs) -> Tuple[str, Future[AthenaCalculationExecution]]:
        """
        Execute Spark code asynchronously.

        Parameters:
        - code: Spark code to execute

        Returns:
            Tuple of (calculation_id, Future[AthenaCalculationExecution])
        """

    def cancel(self, calculation_id: str) -> Future[None]:
        """Cancel Spark calculation by ID asynchronously."""

    def close(self, wait: bool = False) -> None:
        """Close cursor, optionally waiting for running calculations."""
```

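The returned future can be resolved with its blocking `result()` method or handed to an event loop. A minimal sketch, assuming the future behaves like a standard `concurrent.futures.Future` and that the bucket and workgroup names are placeholders:

```python
from pyathena import connect
from pyathena.spark.async_cursor import AsyncSparkCursor

conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=AsyncSparkCursor,
    work_group="spark-workgroup",
)
cursor = conn.cursor()

# execute() returns immediately with the calculation ID and a future
code = """
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("AsyncExample").getOrCreate()
print(spark.version)
"""
calculation_id, future = cursor.execute(code)

# Block until the calculation finishes and inspect its metadata
calculation = future.result()
print(calculation_id, calculation.state)

cursor.close(wait=True)  # optionally wait for still-running calculations
conn.close()
```
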
### Spark Calculation Models

Models representing Spark calculation execution and status information.

```python { .api }
class AthenaCalculationExecution:
    # Calculation states
    STATE_CREATING: str = "CREATING"
    STATE_CREATED: str = "CREATED"
    STATE_QUEUED: str = "QUEUED"
    STATE_RUNNING: str = "RUNNING"
    STATE_CANCELLING: str = "CANCELLING"
    STATE_CANCELLED: str = "CANCELLED"
    STATE_COMPLETED: str = "COMPLETED"
    STATE_FAILED: str = "FAILED"

    calculation_execution_id: Optional[str]
    session_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    state: Optional[str]
    state_change_reason: Optional[str]
    submission_date_time: Optional[datetime]
    completion_date_time: Optional[datetime]
    dpu_execution_in_millis: Optional[int]
    progress: Optional[str]
    std_out_s3_uri: Optional[str]
    std_error_s3_uri: Optional[str]
    result_s3_uri: Optional[str]
    result_type: Optional[str]

class AthenaSessionStatus:
    session_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    idle_since_date_time: Optional[datetime]
    last_modified_date_time: Optional[datetime]
    termination_date_time: Optional[datetime]
    notebook_version: Optional[str]
    session_configuration: Optional[Dict]
    status: Optional[Dict]
```

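The `STATE_*` constants can be used instead of raw strings when checking a calculation's outcome. A minimal sketch, assuming `cursor` is a `SparkCursor` that has just run `execute()` and that the model class is importable from `pyathena.model` (the import path may differ by version):

```python
from pyathena.model import AthenaCalculationExecution  # import path may vary by version

calculation = cursor.calculation_execution
if calculation is not None:
    if calculation.state == AthenaCalculationExecution.STATE_COMPLETED:
        print(f"Completed in {calculation.dpu_execution_in_millis} ms of DPU time")
        print(f"Results: {calculation.result_s3_uri}")
    elif calculation.state == AthenaCalculationExecution.STATE_FAILED:
        print(f"Failed: {calculation.state_change_reason}")
        print(f"Stderr: {calculation.std_error_s3_uri}")
```
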
## Usage Examples

### Basic Spark Code Execution

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

# Connect with Spark cursor
conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=SparkCursor,
    work_group="spark-workgroup"  # Must be configured for Spark
)

cursor = conn.cursor()

# Execute Spark Python code
spark_code = """
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AthenaSparkExample").getOrCreate()

# Read data from S3
df = spark.read.parquet("s3://my-bucket/data/sales/")

# Perform transformations
result = df.groupBy("product_category") \\
    .agg({"amount": "sum", "quantity": "count"}) \\
    .orderBy("sum(amount)", ascending=False)

result.show()

# Save results
result.write.mode("overwrite").parquet("s3://my-bucket/results/category_summary/")
"""

cursor.execute(spark_code)

# Check execution status
print(f"Session ID: {cursor.session_id}")
print(f"Calculation ID: {cursor.calculation_id}")
print(f"State: {cursor.state}")
print(f"Progress: {cursor.progress}")

# Get output and error logs
if cursor.std_out_s3_uri:
    print(f"Output logs: {cursor.std_out_s3_uri}")
if cursor.std_error_s3_uri:
    print(f"Error logs: {cursor.std_error_s3_uri}")

cursor.close()
conn.close()
```

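The stdout and stderr of a calculation are written to S3 rather than returned inline. A minimal boto3 sketch for downloading them, assuming your credentials can read the workgroup's output location and that `cursor` has just executed a calculation; the `read_s3_text` helper is illustrative and not part of PyAthena:

```python
from urllib.parse import urlparse

import boto3

def read_s3_text(s3_uri: str) -> str:
    """Download a small text object (e.g. calculation stdout) from S3."""
    parsed = urlparse(s3_uri)
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=parsed.netloc, Key=parsed.path.lstrip("/"))
    return obj["Body"].read().decode("utf-8")

if cursor.std_out_s3_uri:
    print(read_s3_text(cursor.std_out_s3_uri))
```
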
### Spark SQL Execution

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=SparkCursor,
    work_group="spark-workgroup"
)

cursor = conn.cursor()

# Execute Spark SQL
spark_sql = """
-- Create temporary view from S3 data
CREATE OR REPLACE TEMPORARY VIEW sales
USING PARQUET
OPTIONS (
    path "s3://my-bucket/data/sales/"
);

-- Perform complex analytics
WITH monthly_metrics AS (
    SELECT
        DATE_FORMAT(sale_date, 'yyyy-MM') as sale_month,
        product_category,
        SUM(amount) as total_revenue,
        COUNT(*) as transaction_count,
        COUNT(DISTINCT customer_id) as unique_customers
    FROM sales
    WHERE sale_date >= '2023-01-01'
    GROUP BY DATE_FORMAT(sale_date, 'yyyy-MM'), product_category
),
category_trends AS (
    SELECT
        product_category,
        sale_month,
        total_revenue,
        LAG(total_revenue) OVER (
            PARTITION BY product_category
            ORDER BY sale_month
        ) as prev_month_revenue,
        total_revenue - LAG(total_revenue) OVER (
            PARTITION BY product_category
            ORDER BY sale_month
        ) as revenue_change
    FROM monthly_metrics
)
SELECT
    product_category,
    sale_month,
    total_revenue,
    revenue_change,
    CASE
        WHEN revenue_change > 0 THEN 'Growth'
        WHEN revenue_change < 0 THEN 'Decline'
        ELSE 'Stable'
    END as trend
FROM category_trends
WHERE prev_month_revenue IS NOT NULL
ORDER BY product_category, sale_month;
"""

cursor.execute(spark_sql)

print("Spark SQL execution started")
print(f"Calculation ID: {cursor.calculation_id}")
print(f"Working directory: {cursor.working_directory}")

cursor.close()
conn.close()
```

### Advanced Spark Analytics

```python
import time

from pyathena import connect
from pyathena.spark.cursor import SparkCursor

def spark_ml_pipeline():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Machine learning pipeline with Spark MLlib
    ml_code = """
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.sql.functions import col, when, isnan, count, avg

spark = SparkSession.builder.appName("CustomerSegmentation").getOrCreate()

# Load customer data
customers = spark.read.parquet("s3://my-bucket/data/customers/")

# Data preprocessing
print("Original dataset shape:", customers.count(), len(customers.columns))

# Handle missing values
customers_clean = customers.dropna()
print("After removing nulls:", customers_clean.count())

# Feature engineering
feature_cols = ["age", "annual_income", "spending_score", "total_orders", "avg_order_value"]

# Create feature vector
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
customers_features = assembler.transform(customers_clean)

# Scale features
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
scaler_model = scaler.fit(customers_features)
customers_scaled = scaler_model.transform(customers_features)

# K-means clustering
kmeans = KMeans(featuresCol="scaled_features", predictionCol="cluster", k=5, seed=42)
model = kmeans.fit(customers_scaled)

# Make predictions
predictions = model.transform(customers_scaled)

# Evaluate clustering
evaluator = ClusteringEvaluator(featuresCol="scaled_features", predictionCol="cluster")
silhouette = evaluator.evaluate(predictions)
print(f"Silhouette Score: {silhouette}")

# Analyze clusters
cluster_summary = predictions.groupBy("cluster").agg(
    count("*").alias("cluster_size"),
    avg("age").alias("avg_age"),
    avg("annual_income").alias("avg_income"),
    avg("spending_score").alias("avg_spending_score"),
    avg("total_orders").alias("avg_orders")
)

cluster_summary.show()

# Save results
predictions.select("customer_id", "cluster", *feature_cols).write.mode("overwrite").parquet("s3://my-bucket/results/customer_segments/")

# Save cluster centers
centers_df = spark.createDataFrame(
    [(i, center.toArray().tolist()) for i, center in enumerate(model.clusterCenters())],
    ["cluster_id", "center_coordinates"]
)
centers_df.write.mode("overwrite").json("s3://my-bucket/results/cluster_centers/")

spark.stop()
"""

    cursor.execute(ml_code)

    print("ML pipeline started")
    print(f"Session: {cursor.session_id}")
    print(f"Calculation: {cursor.calculation_id}")

    # Monitor progress
    while cursor.state in ['CREATING', 'CREATED', 'QUEUED', 'RUNNING']:
        time.sleep(10)
        # In practice, you'd poll the status here
        print(f"Current state: {cursor.state}")
        if cursor.progress:
            print(f"Progress: {cursor.progress}")

    print(f"Final state: {cursor.state}")
    if cursor.state == 'COMPLETED':
        print("ML pipeline completed successfully!")
        print("Results saved to: s3://my-bucket/results/customer_segments/")
    else:
        print(f"Pipeline failed: {cursor.state_change_reason}")

    cursor.close()
    conn.close()

# Run ML pipeline
spark_ml_pipeline()
```

### Jupyter Notebook Integration

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

class JupyterSparkNotebook:
    """Integration with Jupyter-style notebook execution."""

    def __init__(self, connection):
        self.conn = connection
        self.cursor = connection.cursor()
        self.cell_results = []

    def execute_cell(self, cell_code, cell_name=None):
        """Execute a notebook cell."""
        print(f"Executing cell: {cell_name or 'Unnamed'}")

        result = self.cursor.execute(cell_code)

        cell_result = {
            'name': cell_name,
            'calculation_id': self.cursor.calculation_id,
            'state': self.cursor.state,
            'start_time': self.cursor.submission_date_time,
            'std_out': self.cursor.std_out_s3_uri,
            'std_error': self.cursor.std_error_s3_uri
        }

        self.cell_results.append(cell_result)
        return result

    def get_execution_summary(self):
        """Get summary of all cell executions."""
        for result in self.cell_results:
            print(f"Cell: {result['name']}")
            print(f"  State: {result['state']}")
            print(f"  Calculation ID: {result['calculation_id']}")
            if result['std_out']:
                print(f"  Output: {result['std_out']}")

    def close(self):
        self.cursor.close()
        self.conn.close()

# Example notebook workflow
def notebook_example():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    notebook = JupyterSparkNotebook(conn)

    # Cell 1: Setup and data loading
    notebook.execute_cell("""
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("NotebookExample").getOrCreate()

# Load data
df = spark.read.parquet("s3://my-bucket/data/transactions/")
print(f"Loaded {df.count()} transactions")
df.printSchema()
""", "Data Loading")

    # Cell 2: Data exploration
    notebook.execute_cell("""
# Basic statistics
df.describe().show()

# Check for nulls
from pyspark.sql.functions import col, isnan, when, count
df.select([count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) for c in df.columns]).show()
""", "Data Exploration")

    # Cell 3: Analysis
    notebook.execute_cell("""
from pyspark.sql.functions import sum, avg, max, min, count, date_format

# Monthly analysis
monthly_stats = df.groupBy(date_format("transaction_date", "yyyy-MM").alias("month")) \\
    .agg(
        sum("amount").alias("total_amount"),
        avg("amount").alias("avg_amount"),
        count("*").alias("transaction_count")
    ) \\
    .orderBy("month")

monthly_stats.show()

# Save results
monthly_stats.write.mode("overwrite").parquet("s3://my-bucket/results/monthly_analysis/")
""", "Monthly Analysis")

    # Get execution summary
    notebook.get_execution_summary()
    notebook.close()

notebook_example()
```

### Async Spark Operations

```python
import asyncio

from pyathena import connect
from pyathena.spark.async_cursor import AsyncSparkCursor

async def async_spark_operations():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=AsyncSparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Execute multiple Spark jobs concurrently
    jobs = [
        ("data_quality", """
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when

spark = SparkSession.builder.appName("DataQuality").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/raw/")
quality_report = df.agg(*[count(when(col(c).isNull(), c)).alias(f"{c}_nulls") for c in df.columns])
quality_report.write.mode("overwrite").json("s3://my-bucket/reports/data_quality/")
"""),
        ("aggregation", """
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, sum

spark = SparkSession.builder.appName("Aggregation").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/processed/")
summary = df.groupBy("category").agg(sum("amount"), count("*"))
summary.write.mode("overwrite").parquet("s3://my-bucket/results/category_summary/")
"""),
        ("feature_engineering", """
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer, VectorAssembler

spark = SparkSession.builder.appName("FeatureEngineering").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/customers/")
# Add feature engineering logic here
features = VectorAssembler(inputCols=["age", "income"], outputCol="features").transform(df)
features.write.mode("overwrite").parquet("s3://my-bucket/features/customer_features/")
""")
    ]

    # Start all jobs
    running_jobs = {}
    for name, code in jobs:
        calc_id, future = cursor.execute(code)
        running_jobs[name] = {
            'calculation_id': calc_id,
            'future': future
        }
        print(f"Started {name} job (ID: {calc_id})")

    # Wait for all jobs to complete
    for name, job_info in running_jobs.items():
        try:
            # Wrap the returned future so it can be awaited in the event loop
            result = await asyncio.wrap_future(job_info['future'])
            print(f"✓ {name} completed successfully")
            print(f"  State: {result.state}")
            if result.dpu_execution_in_millis:
                print(f"  Execution time: {result.dpu_execution_in_millis}ms")
        except Exception as e:
            print(f"✗ {name} failed: {e}")

    cursor.close()
    conn.close()

# Run async Spark operations
asyncio.run(async_spark_operations())
```

### Spark Configuration and Optimization

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

def optimized_spark_job():
    # Connection with Spark-specific configuration
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Optimized Spark configuration
    optimized_code = """
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Configure Spark for performance
spark = SparkSession.builder \\
    .appName("OptimizedAnalytics") \\
    .config("spark.sql.adaptive.enabled", "true") \\
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \\
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \\
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \\
    .getOrCreate()

# Enable dynamic partition pruning
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

# Read large dataset with partitioning
large_df = spark.read.option("mergeSchema", "true") \\
    .parquet("s3://my-bucket/data/large_dataset/")

print(f"Dataset partitions: {large_df.rdd.getNumPartitions()}")

# Efficient aggregation with broadcast hint
lookup_df = spark.read.parquet("s3://my-bucket/data/lookup_table/")
broadcast_lookup = broadcast(lookup_df)

# Join with broadcast hint for small lookup table
result = large_df.join(broadcast_lookup, "key") \\
    .groupBy("category", "region") \\
    .agg(
        sum("amount").alias("total_amount"),
        count("*").alias("record_count"),
        avg("amount").alias("avg_amount")
    ) \\
    .cache()  # Cache for multiple actions

# Write with optimal partitioning
result.coalesce(100) \\
    .write \\
    .mode("overwrite") \\
    .option("compression", "snappy") \\
    .partitionBy("region") \\
    .parquet("s3://my-bucket/results/optimized_output/")

# Show execution plan
result.explain(True)

spark.stop()
"""

    cursor.execute(optimized_code)

    print("Optimized Spark job started")
    print(f"Session ID: {cursor.session_id}")
    print(f"Working directory: {cursor.working_directory}")

    cursor.close()
    conn.close()

optimized_spark_job()
```

## Configuration Requirements

To use Spark integration with PyAthena:

1. **Workgroup Configuration**: Your Athena workgroup must be configured for Spark (see the sketch after this list)
2. **IAM Permissions**: Required permissions for Spark execution and S3 access
3. **S3 Working Directory**: Configured S3 location for Spark session files
4. **Engine Version**: Compatible Athena engine version with Spark support

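A Spark-enabled workgroup can be created through the Athena API before connecting. A minimal boto3 sketch, assuming the role ARN, bucket, and workgroup name are placeholders you replace with your own; the exact engine-version string and configuration fields may vary by region and API version:

```python
import boto3

athena = boto3.client("athena", region_name="us-west-2")

# Create a workgroup backed by the PySpark engine (names and ARNs are placeholders)
athena.create_work_group(
    Name="spark-workgroup",
    Configuration={
        "EngineVersion": {"SelectedEngineVersion": "PySpark engine version 3"},
        "ExecutionRole": "arn:aws:iam::123456789012:role/AthenaSparkExecutionRole",
        "ResultConfiguration": {"OutputLocation": "s3://my-bucket/athena-results/"},
    },
)
```
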
## Supported Spark Features

- **PySpark**: Python API for Spark
- **Spark SQL**: SQL interface for Spark
- **MLlib**: Machine learning library
- **Structured Streaming**: Stream processing (limited support)
- **DataFrame API**: High-level DataFrame operations
- **RDD API**: Low-level resilient distributed datasets

## Performance Considerations

- Use appropriate cluster sizing for your workgroup
- Leverage Spark's adaptive query execution
- Use broadcast joins for small lookup tables
- Partition large datasets appropriately
- Cache intermediate results when accessed multiple times
- Use columnar formats (Parquet) for better performance