# Plugin System

Flyte's extensible plugin framework supports diverse compute requirements, including Apache Spark, machine learning frameworks (TensorFlow, PyTorch), distributed computing (Dask, Ray), database queries, and specialized execution patterns. The plugin system enables Flyte to integrate with a variety of execution backends while maintaining a consistent interface.

## Capabilities

### Apache Spark Jobs

Execute Apache Spark applications with comprehensive configuration support for different Spark application types and cluster management.

```python { .api }
class SparkJob:
    """Apache Spark job configuration."""
    spark_conf: dict[str, str]
    application_file: str
    executor_path: str
    main_application_file: str
    main_class: str
    spark_type: SparkType
    databricks_conf: DatabricksConf
    databricks_token: str
    databricks_instance: str

class SparkType:
    """Spark application type enumeration."""
    PYTHON = 0
    JAVA = 1
    SCALA = 2
    R = 3

class DatabricksConf:
    """Databricks-specific configuration."""
    databricks_token: str
    databricks_instance: str
    tags: dict[str, str]
```

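The Databricks fields route the same Spark application to a Databricks workspace instead of an in-cluster Spark deployment. A minimal sketch using the field names summarized above (exact generated names may vary across flyteidl versions; the path and token are placeholders):

```python
from flyteidl.plugins import spark_pb2

# Run the Spark application on Databricks rather than in-cluster
spark_job = spark_pb2.SparkJob(
    spark_type=spark_pb2.SparkType.PYTHON,
    application_file="dbfs:/my-app/main.py",
    databricks_instance="my-workspace.cloud.databricks.com",
    databricks_token="<token>",  # placeholder; source from a secret store in practice
    databricks_conf=spark_pb2.DatabricksConf(
        tags={"team": "data-platform"}
    )
)
```
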
### PyTorch Distributed Training

Execute PyTorch distributed training jobs with support for various distributed training strategies.

```python { .api }
class PyTorchJob:
    """PyTorch distributed training job configuration."""
    workers: int
    master_replicas: DistributedPyTorchTrainingReplicaSpec
    worker_replicas: DistributedPyTorchTrainingReplicaSpec

class DistributedPyTorchTrainingReplicaSpec:
    """Replica specification for PyTorch distributed training."""
    replicas: int
    image: str
    resources: k8s.ResourceRequirements
    restart_policy: RestartPolicy
    common: CommonReplicaSpec

class RestartPolicy:
    """Restart policy enumeration."""
    NEVER = 0
    ON_FAILURE = 1
    ALWAYS = 2

class CommonReplicaSpec:
    """Common replica specification."""
    replicas: int
    template: k8s.PodTemplateSpec
    restart_policy: RestartPolicy
```

### TensorFlow Distributed Training

Execute TensorFlow distributed training jobs with parameter server and worker configuration.

```python { .api }
class TensorFlowJob:
    """TensorFlow distributed training job configuration."""
    workers: int
    ps_replicas: DistributedTensorFlowTrainingReplicaSpec
    chief_replicas: DistributedTensorFlowTrainingReplicaSpec
    worker_replicas: DistributedTensorFlowTrainingReplicaSpec
    evaluator_replicas: DistributedTensorFlowTrainingReplicaSpec
    run_policy: RunPolicy

class DistributedTensorFlowTrainingReplicaSpec:
    """Replica specification for TensorFlow distributed training."""
    replicas: int
    image: str
    resources: k8s.ResourceRequirements
    restart_policy: RestartPolicy
    common: CommonReplicaSpec

class RunPolicy:
    """Run policy for training jobs."""
    clean_pod_policy: CleanPodPolicy
    ttl_seconds_after_finished: int
    active_deadline_seconds: int
    backoff_limit: int

class CleanPodPolicy:
    """Pod cleanup policy enumeration."""
    NONE = 0
    ALL = 1
    RUNNING = 2
    SUCCEEDED = 3
```

### MPI Jobs

Execute MPI (Message Passing Interface) distributed computing jobs for high-performance computing workloads.

```python { .api }
class MpiJob:
    """MPI distributed computing job configuration."""
    slots: int
    replicas: int
    launcher_replicas: MpiReplicaSpec
    worker_replicas: MpiReplicaSpec
    run_policy: RunPolicy

class MpiReplicaSpec:
    """MPI replica specification."""
    replicas: int
    image: str
    resources: k8s.ResourceRequirements
    restart_policy: RestartPolicy
    common: CommonReplicaSpec
```

### Ray Distributed Computing

Execute Ray distributed computing jobs with cluster configuration and resource management.

```python { .api }
class RayJob:
    """Ray distributed computing job configuration."""
    ray_cluster: RayCluster
    runtime_env: str
    ttl_seconds_after_finished: int
    shutdown_after_job_finishes: bool
    enable_autoscaling: bool

class RayCluster:
    """Ray cluster configuration."""
    head_group_spec: HeadGroupSpec
    worker_group_spec: list[WorkerGroupSpec]
    enable_autoscaling: bool
    autoscaler_options: AutoscalerOptions

class HeadGroupSpec:
    """Ray head node specification."""
    compute_template: str
    image: str
    service_type: str
    enable_ingress: bool
    ray_start_params: dict[str, str]

class WorkerGroupSpec:
    """Ray worker group specification."""
    group_name: str
    compute_template: str
    image: str
    replicas: int
    min_replicas: int
    max_replicas: int
    ray_start_params: dict[str, str]

class AutoscalerOptions:
    """Ray autoscaler configuration."""
    upscaling_speed: float
    downscaling_speed: float
    idle_timeout_seconds: int
```

### Dask Distributed Computing

Execute Dask distributed computing jobs with scheduler and worker configuration.

```python { .api }
class DaskJob:
    """Dask distributed computing job configuration."""
    scheduler: DaskScheduler
    workers: DaskWorkerGroup

class DaskScheduler:
    """Dask scheduler configuration."""
    image: str
    resources: k8s.ResourceRequirements
    cluster: DaskCluster

class DaskWorkerGroup:
    """Dask worker group configuration."""
    number_of_workers: int
    image: str
    resources: k8s.ResourceRequirements

class DaskCluster:
    """Dask cluster configuration."""
    n_workers: int
    threads_per_worker: int
    scheduler_resources: k8s.ResourceRequirements
    worker_resources: k8s.ResourceRequirements
```

### Array Jobs

Execute array jobs with parallel task execution and success criteria configuration.

```python { .api }
class ArrayJob:
    """Array job configuration for parallel execution."""
    parallelism: int
    size: int
    min_successes: int
    min_success_ratio: float
    success_policy: SuccessPolicy

class SuccessPolicy:
    """Success policy for array jobs."""
    min_successes: int
    min_success_ratio: float
```

### Database Query Jobs

Execute SQL queries against various database systems with dialect-specific support.

```python { .api }
class PrestoJob:
    """Presto SQL query job configuration."""
    statement: str
    query_properties: dict[str, str]
    routing_group: str
    catalog: str
    schema: str

class QuboleJob:
    """Qubole data platform job configuration."""
    tags: list[str]
    cluster_label: str
    sdk_version: str
```

### Waitable Tasks

Tasks that wait for external conditions or resources before proceeding.

```python { .api }
class WaitableInterface:
    """Interface for tasks that wait for external conditions."""
    wakeup_policy: WakeupPolicy
    sleep_policy: SleepPolicy

class WakeupPolicy:
    """Policy for waking up waiting tasks."""

class SleepPolicy:
    """Policy for sleeping tasks."""
```

### Kubeflow Integration

Deep integration with Kubeflow for machine learning workflow orchestration. The `TensorFlowJob`, `PyTorchJob`, and `MpiJob` configurations summarized above are executed through the corresponding Kubeflow training operators and are defined in the `flyteidl.plugins.kubeflow` package.

## Usage Examples

### Spark Job Configuration

```python
from flyteidl.core import tasks_pb2
from flyteidl.plugins import spark_pb2
from google.protobuf import json_format, struct_pb2

# Configure Spark application
spark_job = spark_pb2.SparkJob(
    spark_type=spark_pb2.SparkType.PYTHON,
    application_file="s3://my-bucket/spark-app.py",
    spark_conf={
        "spark.executor.memory": "4g",
        "spark.executor.cores": "2",
        "spark.executor.instances": "10",
        "spark.driver.memory": "2g",
        "spark.driver.cores": "1",
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true"
    }
)

# The TaskTemplate's custom field is a protobuf Struct, so convert the plugin
# message to a dict and parse it into a Struct rather than passing a plain dict
custom = struct_pb2.Struct()
json_format.ParseDict(json_format.MessageToDict(spark_job), custom)

# task_id and interface are assumed to be defined elsewhere
task_template = tasks_pb2.TaskTemplate(
    id=task_id,
    type="spark",
    interface=interface,
    custom=custom
)
```

The same pattern applies to every plugin: serialize the plugin message into the task template's `custom` field and set the matching task `type`.

### PyTorch Distributed Training

```python
from flyteidl.plugins.kubeflow import pytorch_pb2
# k8s_pb2 is assumed to provide the Kubernetes ResourceRequirements bindings
# referenced throughout this document

# Configure PyTorch distributed training
pytorch_job = pytorch_pb2.PyTorchJob(
    workers=4,
    master_replicas=pytorch_pb2.DistributedPyTorchTrainingReplicaSpec(
        replicas=1,
        image="pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "nvidia.com/gpu": "1",
                "cpu": "2",
                "memory": "8Gi"
            }
        )
    ),
    worker_replicas=pytorch_pb2.DistributedPyTorchTrainingReplicaSpec(
        replicas=3,
        image="pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "nvidia.com/gpu": "1",
                "cpu": "2",
                "memory": "8Gi"
            }
        )
    )
)
```

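### MPI Job Configuration

MPI jobs pair a launcher replica with a pool of workers. This is a minimal sketch assuming an `mpi_pb2` module that mirrors the `MpiJob` field names summarized above:

```python
from flyteidl.plugins.kubeflow import mpi_pb2  # module location assumed

# Configure an MPI job: one launcher, four workers, two slots per worker
mpi_job = mpi_pb2.MpiJob(
    slots=2,
    launcher_replicas=mpi_pb2.MpiReplicaSpec(
        replicas=1,
        image="mpioperator/mpi-pi:latest"
    ),
    worker_replicas=mpi_pb2.MpiReplicaSpec(
        replicas=4,
        image="mpioperator/mpi-pi:latest"
    )
)
```
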
### Ray Cluster Configuration

```python
from flyteidl.plugins import ray_pb2

# Configure Ray cluster
ray_job = ray_pb2.RayJob(
    ray_cluster=ray_pb2.RayCluster(
        head_group_spec=ray_pb2.HeadGroupSpec(
            compute_template="head-template",
            image="rayproject/ray:2.0.0",
            ray_start_params={
                "dashboard-host": "0.0.0.0",
                "metrics-export-port": "8080"
            }
        ),
        worker_group_spec=[
            ray_pb2.WorkerGroupSpec(
                group_name="workers",
                compute_template="worker-template",
                image="rayproject/ray:2.0.0",
                replicas=5,
                min_replicas=2,
                max_replicas=10,
                ray_start_params={
                    "metrics-export-port": "8080"
                }
            )
        ],
        enable_autoscaling=True,
        autoscaler_options=ray_pb2.AutoscalerOptions(
            upscaling_speed=1.0,
            downscaling_speed=0.5,
            idle_timeout_seconds=60
        )
    ),
    runtime_env='{"pip": ["pandas", "numpy"]}',
    shutdown_after_job_finishes=True  # autoscaling is already enabled on the cluster spec
)
```

### Array Job Configuration

```python
from flyteidl.plugins import array_job_pb2

# Configure array job for parallel processing; min_successes and
# min_success_ratio are alternative ways to express the success criterion,
# and SuccessPolicy bundles the same two fields
array_job = array_job_pb2.ArrayJob(
    parallelism=10,   # Maximum concurrent executions
    size=100,         # Total number of subjobs
    min_successes=90  # Minimum successful completions (90% of 100)
)
```

### SQL Query Jobs

```python
from flyteidl.plugins import presto_pb2

# Configure Presto query
presto_job = presto_pb2.PrestoJob(
    statement="""
        SELECT customer_id, COUNT(*) AS order_count
        FROM orders
        WHERE order_date >= DATE '2023-01-01'
        GROUP BY customer_id
        ORDER BY order_count DESC
        LIMIT 100
    """,
    query_properties={
        "query.max-memory": "10GB",
        "query.max-memory-per-node": "2GB"
    },
    routing_group="batch",
    catalog="hive",
    schema="analytics"
)
```

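A `QuboleJob` is configured the same way. A minimal sketch assuming a `qubole_pb2` module that mirrors the `QuboleJob` field names summarized above:

```python
from flyteidl.plugins import qubole_pb2  # module location assumed

# Configure a Qubole job targeting a labeled cluster
qubole_job = qubole_pb2.QuboleJob(
    tags=["team:analytics", "env:prod"],
    cluster_label="spark-cluster",
    sdk_version="2.1.0"
)
```
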
### Dask Distributed Computing

```python
from flyteidl.plugins import dask_pb2
# k8s_pb2: assumed Kubernetes ResourceRequirements bindings, as above

# Configure Dask job
dask_job = dask_pb2.DaskJob(
    scheduler=dask_pb2.DaskScheduler(
        image="daskdev/dask:latest",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "cpu": "1",
                "memory": "2Gi"
            }
        )
    ),
    workers=dask_pb2.DaskWorkerGroup(
        number_of_workers=5,
        image="daskdev/dask:latest",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "cpu": "2",
                "memory": "4Gi"
            }
        )
    )
)
```

### TensorFlow Distributed Training

```python
from flyteidl.plugins.kubeflow import tensorflow_pb2
# k8s_pb2: assumed Kubernetes ResourceRequirements bindings, as above

# Configure TensorFlow distributed training
tensorflow_job = tensorflow_pb2.TensorFlowJob(
    workers=4,
    ps_replicas=tensorflow_pb2.DistributedTensorFlowTrainingReplicaSpec(
        replicas=2,
        image="tensorflow/tensorflow:2.8.0-gpu",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "cpu": "2",
                "memory": "4Gi"
            }
        )
    ),
    chief_replicas=tensorflow_pb2.DistributedTensorFlowTrainingReplicaSpec(
        replicas=1,
        image="tensorflow/tensorflow:2.8.0-gpu",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "nvidia.com/gpu": "1",
                "cpu": "4",
                "memory": "8Gi"
            }
        )
    ),
    worker_replicas=tensorflow_pb2.DistributedTensorFlowTrainingReplicaSpec(
        replicas=3,
        image="tensorflow/tensorflow:2.8.0-gpu",
        resources=k8s_pb2.ResourceRequirements(
            requests={
                "nvidia.com/gpu": "1",
                "cpu": "4",
                "memory": "8Gi"
            }
        )
    )
)
```

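### Training Job Run Policy

`RunPolicy` controls pod cleanup, time-to-live, deadlines, and retries for the Kubeflow training jobs. A minimal sketch assuming a `common_pb2` module that exposes `RunPolicy` and `CleanPodPolicy` under the names summarized above:

```python
from flyteidl.plugins.kubeflow import common_pb2  # module location assumed
from flyteidl.plugins.kubeflow import tensorflow_pb2

# Remove still-running pods once the job finishes, keep the job object for an
# hour, cap total runtime at two hours, and retry up to three times
run_policy = common_pb2.RunPolicy(
    clean_pod_policy=common_pb2.CleanPodPolicy.RUNNING,
    ttl_seconds_after_finished=3600,
    active_deadline_seconds=7200,
    backoff_limit=3
)

tensorflow_job = tensorflow_pb2.TensorFlowJob(
    workers=4,
    run_policy=run_policy
)
```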