
# Spark Integration

Integration with Athena's Spark execution engine for distributed processing, Jupyter notebook compatibility, and advanced analytics workloads. Enables running Spark code directly on Athena's serverless Spark engine.

## Capabilities

### Spark Cursor

Cursor for executing Spark calculations on Athena's Spark engine, providing distributed processing capabilities and integration with Jupyter notebooks.

```python { .api }
class SparkCursor:
    session_id: str
    calculation_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    state: Optional[str]
    state_change_reason: Optional[str]
    submission_date_time: Optional[datetime]
    completion_date_time: Optional[datetime]
    dpu_execution_in_millis: Optional[int]
    progress: Optional[str]
    std_out_s3_uri: Optional[str]
    std_error_s3_uri: Optional[str]
    result_s3_uri: Optional[str]
    result_type: Optional[str]

    def execute(self, code: str, **kwargs) -> SparkCursor:
        """
        Execute Spark code on Athena's Spark engine.

        Parameters:
        - code: Spark code to execute (Python, Scala, or SQL)
        - **kwargs: Additional execution options

        Returns:
        Self for method chaining
        """

    def cancel(self) -> None:
        """Cancel the currently executing Spark calculation."""

    def close(self) -> None:
        """Close the cursor and clean up resources."""

    @property
    def calculation_execution(self) -> Optional[AthenaCalculationExecution]:
        """Get the current calculation execution metadata."""
```

### Async Spark Cursor

Asynchronous version of SparkCursor for non-blocking Spark calculations.

```python { .api }
class AsyncSparkCursor:
    def execute(self, code: str, **kwargs) -> Tuple[str, Future[AthenaCalculationExecution]]:
        """
        Execute Spark code asynchronously.

        Parameters:
        - code: Spark code to execute
        - **kwargs: Additional execution options

        Returns:
        Tuple of (calculation_id, Future[AthenaCalculationExecution])
        """

    def cancel(self, calculation_id: str) -> Future[None]:
        """Cancel a Spark calculation by ID asynchronously."""

    def close(self, wait: bool = False) -> None:
        """Close the cursor, optionally waiting for running calculations."""
```
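
The returned future behaves like a `concurrent.futures.Future`, so it can also be consumed without asyncio. Below is a minimal sketch under that assumption, using a connection created with `cursor_class=AsyncSparkCursor` and a Spark-enabled workgroup (names are placeholders):

```python
from pyathena import connect
from pyathena.spark.async_cursor import AsyncSparkCursor

conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=AsyncSparkCursor,
    work_group="spark-workgroup",  # assumed Spark-enabled workgroup
)
cursor = conn.cursor()

# execute() returns immediately with the calculation ID and a Future.
calculation_id, future = cursor.execute("print('hello from Athena Spark')")

# Block until the calculation finishes, then inspect its metadata.
execution = future.result()
print(calculation_id, execution.state)

cursor.close()
conn.close()
```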

### Spark Calculation Models

Models representing Spark calculation execution and session status information.

```python { .api }
class AthenaCalculationExecution:
    # Calculation states
    STATE_CREATING: str = "CREATING"
    STATE_CREATED: str = "CREATED"
    STATE_QUEUED: str = "QUEUED"
    STATE_RUNNING: str = "RUNNING"
    STATE_CANCELLING: str = "CANCELLING"
    STATE_CANCELLED: str = "CANCELLED"
    STATE_COMPLETED: str = "COMPLETED"
    STATE_FAILED: str = "FAILED"

    calculation_execution_id: Optional[str]
    session_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    state: Optional[str]
    state_change_reason: Optional[str]
    submission_date_time: Optional[datetime]
    completion_date_time: Optional[datetime]
    dpu_execution_in_millis: Optional[int]
    progress: Optional[str]
    std_out_s3_uri: Optional[str]
    std_error_s3_uri: Optional[str]
    result_s3_uri: Optional[str]
    result_type: Optional[str]


class AthenaSessionStatus:
    session_id: Optional[str]
    description: Optional[str]
    working_directory: Optional[str]
    idle_since_date_time: Optional[datetime]
    last_modified_date_time: Optional[datetime]
    termination_date_time: Optional[datetime]
    notebook_version: Optional[str]
    session_configuration: Optional[Dict]
    status: Optional[Dict]
```
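
The state constants above can be used to decide when a calculation has finished. A minimal helper sketch follows; the import path and helper name are illustrative, so adjust them to where `AthenaCalculationExecution` lives in your PyAthena version:

```python
# Assumed import path; AthenaCalculationExecution may live elsewhere in your version.
from pyathena.model import AthenaCalculationExecution

# Terminal states: once reached, the calculation will not change state again.
TERMINAL_STATES = {
    AthenaCalculationExecution.STATE_COMPLETED,
    AthenaCalculationExecution.STATE_FAILED,
    AthenaCalculationExecution.STATE_CANCELLED,
}


def is_finished(state: str) -> bool:
    """Return True once a Spark calculation has reached a terminal state."""
    return state in TERMINAL_STATES
```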

## Usage Examples

### Basic Spark Code Execution

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

# Connect with Spark cursor
conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=SparkCursor,
    work_group="spark-workgroup"  # Must be configured for Spark
)

cursor = conn.cursor()

# Execute Spark Python code
spark_code = """
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AthenaSparkExample").getOrCreate()

# Read data from S3
df = spark.read.parquet("s3://my-bucket/data/sales/")

# Perform transformations
result = df.groupBy("product_category") \\
    .agg({"amount": "sum", "quantity": "count"}) \\
    .orderBy("sum(amount)", ascending=False)

result.show()

# Save results
result.write.mode("overwrite").parquet("s3://my-bucket/results/category_summary/")
"""

cursor.execute(spark_code)

# Check execution status
print(f"Session ID: {cursor.session_id}")
print(f"Calculation ID: {cursor.calculation_id}")
print(f"State: {cursor.state}")
print(f"Progress: {cursor.progress}")

# Get output and error logs
if cursor.std_out_s3_uri:
    print(f"Output logs: {cursor.std_out_s3_uri}")
if cursor.std_error_s3_uri:
    print(f"Error logs: {cursor.std_error_s3_uri}")

cursor.close()
conn.close()
```

### Spark SQL Execution

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

conn = connect(
    s3_staging_dir="s3://my-bucket/athena-results/",
    region_name="us-west-2",
    cursor_class=SparkCursor,
    work_group="spark-workgroup"
)

cursor = conn.cursor()

# Execute Spark SQL
spark_sql = """
-- Create temporary view from S3 data
CREATE OR REPLACE TEMPORARY VIEW sales
USING PARQUET
OPTIONS (
    path "s3://my-bucket/data/sales/"
);

-- Perform complex analytics
WITH monthly_metrics AS (
    SELECT
        DATE_FORMAT(sale_date, 'yyyy-MM') as sale_month,
        product_category,
        SUM(amount) as total_revenue,
        COUNT(*) as transaction_count,
        COUNT(DISTINCT customer_id) as unique_customers
    FROM sales
    WHERE sale_date >= '2023-01-01'
    GROUP BY DATE_FORMAT(sale_date, 'yyyy-MM'), product_category
),
category_trends AS (
    SELECT
        product_category,
        sale_month,
        total_revenue,
        LAG(total_revenue) OVER (
            PARTITION BY product_category
            ORDER BY sale_month
        ) as prev_month_revenue,
        total_revenue - LAG(total_revenue) OVER (
            PARTITION BY product_category
            ORDER BY sale_month
        ) as revenue_change
    FROM monthly_metrics
)
SELECT
    product_category,
    sale_month,
    total_revenue,
    revenue_change,
    CASE
        WHEN revenue_change > 0 THEN 'Growth'
        WHEN revenue_change < 0 THEN 'Decline'
        ELSE 'Stable'
    END as trend
FROM category_trends
WHERE prev_month_revenue IS NOT NULL
ORDER BY product_category, sale_month;
"""

cursor.execute(spark_sql)

print("Spark SQL execution started")
print(f"Calculation ID: {cursor.calculation_id}")
print(f"Working directory: {cursor.working_directory}")

cursor.close()
conn.close()
```

### Advanced Spark Analytics

```python
import time

from pyathena import connect
from pyathena.spark.cursor import SparkCursor

def spark_ml_pipeline():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Machine learning pipeline with Spark MLlib
    ml_code = """
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.sql.functions import col, when, isnan, count, avg

spark = SparkSession.builder.appName("CustomerSegmentation").getOrCreate()

# Load customer data
customers = spark.read.parquet("s3://my-bucket/data/customers/")

# Data preprocessing
print("Original dataset shape:", customers.count(), len(customers.columns))

# Handle missing values
customers_clean = customers.dropna()
print("After removing nulls:", customers_clean.count())

# Feature engineering
feature_cols = ["age", "annual_income", "spending_score", "total_orders", "avg_order_value"]

# Create feature vector
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
customers_features = assembler.transform(customers_clean)

# Scale features
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
scaler_model = scaler.fit(customers_features)
customers_scaled = scaler_model.transform(customers_features)

# K-means clustering
kmeans = KMeans(featuresCol="scaled_features", predictionCol="cluster", k=5, seed=42)
model = kmeans.fit(customers_scaled)

# Make predictions
predictions = model.transform(customers_scaled)

# Evaluate clustering
evaluator = ClusteringEvaluator(featuresCol="scaled_features", predictionCol="cluster")
silhouette = evaluator.evaluate(predictions)
print(f"Silhouette Score: {silhouette}")

# Analyze clusters
cluster_summary = predictions.groupBy("cluster").agg(
    count("*").alias("cluster_size"),
    avg("age").alias("avg_age"),
    avg("annual_income").alias("avg_income"),
    avg("spending_score").alias("avg_spending_score"),
    avg("total_orders").alias("avg_orders")
)

cluster_summary.show()

# Save results
predictions.select("customer_id", "cluster", *feature_cols) \\
    .write.mode("overwrite").parquet("s3://my-bucket/results/customer_segments/")

# Save cluster centers
centers_df = spark.createDataFrame(
    [(i, center.toArray().tolist()) for i, center in enumerate(model.clusterCenters())],
    ["cluster_id", "center_coordinates"]
)
centers_df.write.mode("overwrite").json("s3://my-bucket/results/cluster_centers/")

spark.stop()
"""

    cursor.execute(ml_code)

    print("ML pipeline started")
    print(f"Session: {cursor.session_id}")
    print(f"Calculation: {cursor.calculation_id}")

    # Monitor progress
    while cursor.state in ['CREATING', 'CREATED', 'QUEUED', 'RUNNING']:
        time.sleep(10)
        # In practice, re-fetch the calculation status here
        print(f"Current state: {cursor.state}")
        if cursor.progress:
            print(f"Progress: {cursor.progress}")

    print(f"Final state: {cursor.state}")
    if cursor.state == 'COMPLETED':
        print("ML pipeline completed successfully!")
        print("Results saved to: s3://my-bucket/results/customer_segments/")
    else:
        print(f"Pipeline failed: {cursor.state_change_reason}")

    cursor.close()
    conn.close()

# Run ML pipeline
spark_ml_pipeline()
```

### Jupyter Notebook Integration

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

class JupyterSparkNotebook:
    """Integration with Jupyter-style notebook execution."""

    def __init__(self, connection):
        self.conn = connection
        self.cursor = connection.cursor()
        self.cell_results = []

    def execute_cell(self, cell_code, cell_name=None):
        """Execute a notebook cell."""
        print(f"Executing cell: {cell_name or 'Unnamed'}")

        result = self.cursor.execute(cell_code)

        cell_result = {
            'name': cell_name,
            'calculation_id': self.cursor.calculation_id,
            'state': self.cursor.state,
            'start_time': self.cursor.submission_date_time,
            'std_out': self.cursor.std_out_s3_uri,
            'std_error': self.cursor.std_error_s3_uri
        }

        self.cell_results.append(cell_result)
        return result

    def get_execution_summary(self):
        """Get summary of all cell executions."""
        for result in self.cell_results:
            print(f"Cell: {result['name']}")
            print(f"  State: {result['state']}")
            print(f"  Calculation ID: {result['calculation_id']}")
            if result['std_out']:
                print(f"  Output: {result['std_out']}")

    def close(self):
        self.cursor.close()
        self.conn.close()

# Example notebook workflow
def notebook_example():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    notebook = JupyterSparkNotebook(conn)

    # Cell 1: Setup and data loading
    notebook.execute_cell("""
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("NotebookExample").getOrCreate()

# Load data
df = spark.read.parquet("s3://my-bucket/data/transactions/")
print(f"Loaded {df.count()} transactions")
df.printSchema()
""", "Data Loading")

    # Cell 2: Data exploration
    notebook.execute_cell("""
# Basic statistics
df.describe().show()

# Check for nulls
from pyspark.sql.functions import col, isnan, when, count
df.select([count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) for c in df.columns]).show()
""", "Data Exploration")

    # Cell 3: Analysis
    notebook.execute_cell("""
from pyspark.sql.functions import sum, avg, max, min, count, date_format

# Monthly analysis
monthly_stats = df.groupBy(date_format("transaction_date", "yyyy-MM").alias("month")) \\
    .agg(
        sum("amount").alias("total_amount"),
        avg("amount").alias("avg_amount"),
        count("*").alias("transaction_count")
    ) \\
    .orderBy("month")

monthly_stats.show()

# Save results
monthly_stats.write.mode("overwrite").parquet("s3://my-bucket/results/monthly_analysis/")
""", "Monthly Analysis")

    # Get execution summary
    notebook.get_execution_summary()
    notebook.close()

notebook_example()
```

### Async Spark Operations

```python
import asyncio

from pyathena import connect
from pyathena.spark.async_cursor import AsyncSparkCursor

async def async_spark_operations():
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=AsyncSparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Execute multiple Spark jobs concurrently
    jobs = [
        ("data_quality", """
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when

spark = SparkSession.builder.appName("DataQuality").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/raw/")
quality_report = df.agg(*[count(when(col(c).isNull(), c)).alias(f"{c}_nulls") for c in df.columns])
quality_report.write.mode("overwrite").json("s3://my-bucket/reports/data_quality/")
"""),
        ("aggregation", """
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, sum

spark = SparkSession.builder.appName("Aggregation").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/processed/")
summary = df.groupBy("category").agg(sum("amount"), count("*"))
summary.write.mode("overwrite").parquet("s3://my-bucket/results/category_summary/")
"""),
        ("feature_engineering", """
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer, VectorAssembler

spark = SparkSession.builder.appName("FeatureEngineering").getOrCreate()
df = spark.read.parquet("s3://my-bucket/data/customers/")

# Add feature engineering logic here
features = VectorAssembler(inputCols=["age", "income"], outputCol="features").transform(df)
features.write.mode("overwrite").parquet("s3://my-bucket/features/customer_features/")
""")
    ]

    # Start all jobs
    running_jobs = {}
    for name, code in jobs:
        calc_id, future = cursor.execute(code)
        running_jobs[name] = {
            'calculation_id': calc_id,
            'future': future
        }
        print(f"Started {name} job (ID: {calc_id})")

    # Wait for all jobs to complete
    for name, job_info in running_jobs.items():
        try:
            # Wrap the concurrent.futures.Future so it can be awaited
            result = await asyncio.wrap_future(job_info['future'])
            print(f"✓ {name} completed successfully")
            print(f"  State: {result.state}")
            if result.dpu_execution_in_millis:
                print(f"  Execution time: {result.dpu_execution_in_millis}ms")
        except Exception as e:
            print(f"✗ {name} failed: {e}")

    cursor.close()
    conn.close()

# Run async Spark operations
asyncio.run(async_spark_operations())
```

### Spark Configuration and Optimization

```python
from pyathena import connect
from pyathena.spark.cursor import SparkCursor

def optimized_spark_job():
    # Connection with Spark-specific configuration
    conn = connect(
        s3_staging_dir="s3://my-bucket/athena-results/",
        region_name="us-west-2",
        cursor_class=SparkCursor,
        work_group="spark-workgroup"
    )

    cursor = conn.cursor()

    # Optimized Spark configuration
    optimized_code = """
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Configure Spark for performance
spark = SparkSession.builder \\
    .appName("OptimizedAnalytics") \\
    .config("spark.sql.adaptive.enabled", "true") \\
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \\
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \\
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \\
    .getOrCreate()

# Enable dynamic partition pruning
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

# Read large dataset with partitioning
large_df = spark.read.option("mergeSchema", "true") \\
    .parquet("s3://my-bucket/data/large_dataset/")

print(f"Dataset partitions: {large_df.rdd.getNumPartitions()}")

# Efficient aggregation with broadcast hint
lookup_df = spark.read.parquet("s3://my-bucket/data/lookup_table/")
broadcast_lookup = broadcast(lookup_df)

# Join with broadcast hint for small lookup table
result = large_df.join(broadcast_lookup, "key") \\
    .groupBy("category", "region") \\
    .agg(
        sum("amount").alias("total_amount"),
        count("*").alias("record_count"),
        avg("amount").alias("avg_amount")
    ) \\
    .cache()  # Cache for multiple actions

# Write with optimal partitioning
result.coalesce(100) \\
    .write \\
    .mode("overwrite") \\
    .option("compression", "snappy") \\
    .partitionBy("region") \\
    .parquet("s3://my-bucket/results/optimized_output/")

# Show execution plan
result.explain(True)

spark.stop()
"""

    cursor.execute(optimized_code)

    print("Optimized Spark job started")
    print(f"Session ID: {cursor.session_id}")
    print(f"Working directory: {cursor.working_directory}")

    cursor.close()
    conn.close()

optimized_spark_job()
```

## Configuration Requirements

612

613

To use Spark integration with PyAthena:

614

615

1. **Workgroup Configuration**: Your Athena workgroup must be configured for Spark

616

2. **IAM Permissions**: Required permissions for Spark execution and S3 access

617

3. **S3 Working Directory**: Configured S3 location for Spark session files

618

4. **Engine Version**: Compatible Athena engine version with Spark support

619

620
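
For the workgroup requirement, the workgroup's engine must be the Athena PySpark engine. A hedged provisioning sketch using boto3 follows; the workgroup name, S3 location, and role ARN are placeholders, and the exact configuration keys should be confirmed against the current Athena `CreateWorkGroup` API:

```python
import boto3

athena = boto3.client("athena", region_name="us-west-2")

# Sketch only: create a Spark-enabled workgroup (names and ARNs are placeholders).
athena.create_work_group(
    Name="spark-workgroup",
    Description="Workgroup for PyAthena Spark cursors",
    Configuration={
        # Select the Spark engine for this workgroup.
        "EngineVersion": {"SelectedEngineVersion": "PySpark engine version 3"},
        # S3 location for calculation results and session files.
        "ResultConfiguration": {"OutputLocation": "s3://my-bucket/athena-results/"},
        # IAM role Athena assumes to run Spark calculations and access S3.
        "ExecutionRole": "arn:aws:iam::123456789012:role/AthenaSparkExecutionRole",
    },
)
```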

## Supported Spark Features

- **PySpark**: Python API for Spark
- **Spark SQL**: SQL interface for Spark
- **MLlib**: Machine learning library
- **Structured Streaming**: Stream processing (limited support)
- **DataFrame API**: High-level DataFrame operations
- **RDD API**: Low-level resilient distributed datasets
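
Most of the examples above use the DataFrame API; the Spark SQL and RDD interfaces are available from the same session. A brief illustrative snippet that could be submitted with `cursor.execute()` (the `sales` table and bucket path are placeholders):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("FeatureTour").getOrCreate()

# Spark SQL: query a table or view visible to the session (placeholder name).
spark.sql("SELECT COUNT(*) AS n FROM sales").show()

# RDD API: low-level transformation over the rows of a DataFrame.
df = spark.read.parquet("s3://my-bucket/data/sales/")
amounts = df.rdd.map(lambda row: row["amount"])
print("Total amount:", amounts.sum())
```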

## Performance Considerations

- Use appropriate cluster sizing for your workgroup
- Leverage Spark's adaptive query execution
- Use broadcast joins for small lookup tables
- Partition large datasets appropriately
- Cache intermediate results when accessed multiple times
- Use columnar formats (Parquet) for better performance