0
# SageMaker Machine Learning
1
2
Amazon SageMaker integration for end-to-end machine learning workflows, including model training, hyperparameter tuning, deployment, and batch inference. It provides MLOps capabilities for building, training, and deploying ML models at scale.
3
4
## Capabilities
5
6
### SageMaker Hook
7
8
Core SageMaker client providing ML lifecycle management functionality.
9
10
```python { .api }
11
class SageMakerHook(AwsBaseHook):
    """Core SageMaker client providing ML lifecycle management functionality."""

    def __init__(self, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Initialize SageMaker Hook.

        Parameters:
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          AwsBaseHook; confirm against the base class)
        """

    def create_training_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: int = None,
    ) -> dict:
        """
        Create SageMaker training job.

        Parameters:
        - config: Training job configuration
        - wait_for_completion: Wait for job completion (default: True)
        - print_log: Print training logs (default: True)
        - check_interval: Status check interval in seconds (default: 30)
        - max_ingestion_time: Maximum log ingestion time (default: None)

        Returns:
        Training job details
        """

    def create_tuning_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval in seconds (default: 30)

        Returns:
        Tuning job details
        """

    def create_model(self, config: dict) -> dict:
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration

        Returns:
        Model details
        """

    def create_endpoint_config(self, config: dict) -> dict:
        """
        Create endpoint configuration.

        Parameters:
        - config: Endpoint configuration

        Returns:
        Endpoint config details
        """

    def create_endpoint(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - wait_for_completion: Wait for endpoint to be in service (default: True)
        - check_interval: Status check interval in seconds (default: 30)

        Returns:
        Endpoint details
        """

    def create_transform_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create batch transform job.

        Parameters:
        - config: Transform job configuration
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval in seconds (default: 30)

        Returns:
        Transform job details
        """

    def create_processing_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create processing job.

        Parameters:
        - config: Processing job configuration
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval in seconds (default: 30)

        Returns:
        Processing job details
        """

    def describe_training_job(self, name: str) -> dict:
        """
        Get training job details.

        Parameters:
        - name: Training job name

        Returns:
        Training job description
        """

    def describe_model(self, name: str) -> dict:
        """
        Get model details.

        Parameters:
        - name: Model name

        Returns:
        Model description
        """

    def describe_endpoint(self, name: str) -> dict:
        """
        Get endpoint details.

        Parameters:
        - name: Endpoint name

        Returns:
        Endpoint description
        """

    def delete_model(self, name: str) -> None:
        """
        Delete SageMaker model.

        Parameters:
        - name: Model name
        """

    def delete_endpoint_config(self, name: str) -> None:
        """
        Delete endpoint configuration.

        Parameters:
        - name: Endpoint config name
        """

    def delete_endpoint(self, name: str) -> None:
        """
        Delete SageMaker endpoint.

        Parameters:
        - name: Endpoint name
        """
165
```
166
167
### SageMaker Operators
168
169
Task implementations for SageMaker ML operations.
170
171
```python { .api }
172
class SageMakerTrainingOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: int = None,
        **kwargs,
    ):
        """
        Start SageMaker training job.

        Parameters:
        - config: Training job configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - wait_for_completion: Wait for job completion (default: True)
        - print_log: Print training logs (default: True)
        - check_interval: Status check interval (default: 30)
        - max_ingestion_time: Maximum log ingestion time (default: None)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
185
186
class SageMakerTuningOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval (default: 30)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
197
198
class SageMakerModelOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
207
208
class SageMakerEndpointOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - wait_for_completion: Wait for endpoint creation (default: True)
        - check_interval: Status check interval (default: 30)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
219
220
class SageMakerTransformOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start batch transform job.

        Parameters:
        - config: Transform job configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval (default: 30)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
231
232
class SageMakerProcessingOperator(BaseOperator):
    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start processing job.

        Parameters:
        - config: Processing job configuration
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - wait_for_completion: Wait for job completion (default: True)
        - check_interval: Status check interval (default: 30)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
243
244
class SageMakerDeleteModelOperator(BaseOperator):
    def __init__(
        self,
        model_name: str,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Delete SageMaker model.

        Parameters:
        - model_name: Model name to delete
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseOperator; confirm against the base class)
        """
253
```
254
255
### SageMaker Sensors
256
257
Monitoring tasks for SageMaker job and endpoint states.
258
259
```python { .api }
260
class SageMakerTrainingSensor(BaseSensorOperator):
    def __init__(
        self,
        job_name: str,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Wait for SageMaker training job completion.

        Parameters:
        - job_name: Training job name
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseSensorOperator; confirm against the base class)
        """
269
270
class SageMakerTuningSensor(BaseSensorOperator):
    def __init__(
        self,
        job_name: str,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Wait for hyperparameter tuning job completion.

        Parameters:
        - job_name: Tuning job name
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseSensorOperator; confirm against the base class)
        """
279
280
class SageMakerTransformSensor(BaseSensorOperator):
    def __init__(
        self,
        job_name: str,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Wait for batch transform job completion.

        Parameters:
        - job_name: Transform job name
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseSensorOperator; confirm against the base class)
        """
289
290
class SageMakerEndpointSensor(BaseSensorOperator):
    def __init__(
        self,
        endpoint_name: str,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        """
        Wait for SageMaker endpoint to be in service.

        Parameters:
        - endpoint_name: Endpoint name
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseSensorOperator; confirm against the base class)
        """
299
```
300
301
### SageMaker Triggers
302
303
Asynchronous triggers for SageMaker operations.
304
305
```python { .api }
306
class SageMakerTrigger(BaseTrigger):
    def __init__(
        self,
        job_name: str,
        job_type: str,
        aws_conn_id: str = 'aws_default',
        poll_interval: int = 30,
        **kwargs,
    ):
        """
        Asynchronous trigger for SageMaker job monitoring.

        Parameters:
        - job_name: Job name to monitor
        - job_type: Type of job ('training', 'tuning', 'transform', 'processing')
        - aws_conn_id: AWS connection ID (default: 'aws_default')
        - poll_interval: Polling interval in seconds (default: 30)
        - **kwargs: Additional keyword arguments (presumably passed on to
          BaseTrigger; confirm against the base class)
        """
317
```
318
319
## Usage Examples
320
321
### End-to-End ML Pipeline
322
323
```python
324
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerTrainingOperator,
    SageMakerModelOperator,
    SageMakerEndpointOperator
)

dag = DAG('ml_pipeline', start_date=datetime(2023, 1, 1))

# Training job configuration
training_config = {
    'TrainingJobName': 'customer-churn-model-{{ ds }}',
    'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'AlgorithmSpecification': {
        'TrainingImage': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'TrainingInputMode': 'File'
    },
    'InputDataConfig': [
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/train/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        },
        {
            'ChannelName': 'validation',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/validation/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        }
    ],
    'OutputDataConfig': {
        'S3OutputPath': 's3://ml-model-artifacts/customer-churn/'
    },
    'ResourceConfig': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1,
        'VolumeSizeInGB': 30
    },
    'StoppingCondition': {
        'MaxRuntimeInSeconds': 3600
    },
    'HyperParameters': {
        'max_depth': '5',
        'eta': '0.2',
        'gamma': '4',
        'min_child_weight': '6',
        'subsample': '0.8',
        'silent': '0',
        'objective': 'binary:logistic',
        'num_round': '100'
    }
}

# Train model
train_model = SageMakerTrainingOperator(
    task_id='train_churn_model',
    config=training_config,
    wait_for_completion=True,
    print_log=True,
    dag=dag
)

# Create model
# ModelDataUrl points below the training job's S3OutputPath, in a folder named
# after the training job ('customer-churn-model-{{ ds }}').
model_config = {
    'ModelName': 'customer-churn-model-{{ ds }}',
    'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'PrimaryContainer': {
        'Image': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'ModelDataUrl': 's3://ml-model-artifacts/customer-churn/customer-churn-model-{{ ds }}/output/model.tar.gz',
        'Environment': {
            'SAGEMAKER_PROGRAM': 'inference.py',
            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/code'
        }
    }
}

create_model = SageMakerModelOperator(
    task_id='create_model',
    config=model_config,
    dag=dag
)

# Deploy endpoint
endpoint_config = {
    'EndpointName': 'customer-churn-endpoint',
    'EndpointConfigName': 'customer-churn-config-{{ ds }}',
    'ProductionVariants': [
        {
            'VariantName': 'primary',
            'ModelName': 'customer-churn-model-{{ ds }}',
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.t2.medium',
            'InitialVariantWeight': 1
        }
    ]
}

deploy_endpoint = SageMakerEndpointOperator(
    task_id='deploy_endpoint',
    config=endpoint_config,
    wait_for_completion=True,
    dag=dag
)

# Task order: train, then register the model, then deploy it
train_model >> create_model >> deploy_endpoint
442
```
443
444
### Hyperparameter Tuning
445
446
```python
447
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator

# Hyperparameters searched by the tuner; keep in sync with 'ParameterRanges' below
tuned_hyperparameters = ('max_depth', 'num_round', 'eta', 'subsample')

# 'TrainingJobDefinition' is not a full training-job request: it has no
# 'TrainingJobName' key, and fixed hyperparameters are passed as
# 'StaticHyperParameters' instead of 'HyperParameters'. Derive it from the
# training configuration of the pipeline example (training_config).
training_job_definition = {
    key: value
    for key, value in training_config.items()
    if key not in ('TrainingJobName', 'HyperParameters')
}
# Tuned hyperparameters are left out so each one is specified in one place only
training_job_definition['StaticHyperParameters'] = {
    name: value
    for name, value in training_config['HyperParameters'].items()
    if name not in tuned_hyperparameters
}

# Hyperparameter tuning configuration
tuning_config = {
    'HyperParameterTuningJobName': 'xgboost-tuning-{{ ds }}',
    'HyperParameterTuningJobConfig': {
        'Strategy': 'Bayesian',
        'HyperParameterTuningJobObjective': {
            'Type': 'Maximize',
            'MetricName': 'validation:auc'
        },
        'ResourceLimits': {
            'MaxNumberOfTrainingJobs': 20,
            'MaxParallelTrainingJobs': 3
        },
        'ParameterRanges': {
            'IntegerParameterRanges': [
                {
                    'Name': 'max_depth',
                    'MinValue': '1',
                    'MaxValue': '10'
                },
                {
                    'Name': 'num_round',
                    'MinValue': '50',
                    'MaxValue': '200'
                }
            ],
            'ContinuousParameterRanges': [
                {
                    'Name': 'eta',
                    'MinValue': '0.1',
                    'MaxValue': '0.5'
                },
                {
                    'Name': 'subsample',
                    'MinValue': '0.5',
                    'MaxValue': '1.0'
                }
            ]
        }
    },
    'TrainingJobDefinition': training_job_definition
}

tune_hyperparameters = SageMakerTuningOperator(
    task_id='tune_hyperparameters',
    config=tuning_config,
    wait_for_completion=True,
    dag=dag
)
498
```
499
500
### Batch Inference
501
502
```python
503
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

# Batch transform configuration
transform_config = {
    'TransformJobName': 'batch-inference-{{ ds }}',
    'ModelName': 'customer-churn-model-{{ ds }}',
    'TransformInput': {
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://ml-inference-data/batch/{{ ds }}/'
            }
        },
        'ContentType': 'text/csv',
        'SplitType': 'Line'
    },
    'TransformOutput': {
        'S3OutputPath': 's3://ml-inference-results/{{ ds }}/',
        'Accept': 'text/csv'
    },
    'TransformResources': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1
    }
}

# `dag` is the DAG object created in the pipeline example
batch_inference = SageMakerTransformOperator(
    task_id='batch_inference',
    config=transform_config,
    wait_for_completion=True,
    dag=dag
)
535
```
536
537
## Types
538
539
```python { .api }
540
# SageMaker job states
541
class SageMakerJobState:
    """String constants naming SageMaker job states."""

    IN_PROGRESS = 'InProgress'
    COMPLETED = 'Completed'
    FAILED = 'Failed'
    STOPPING = 'Stopping'
    STOPPED = 'Stopped'
547
548
# Instance types
549
class SageMakerInstanceType:
    """String constants for a selection of SageMaker ML instance type names."""

    ML_T2_MEDIUM = 'ml.t2.medium'
    ML_T2_LARGE = 'ml.t2.large'
    ML_M5_LARGE = 'ml.m5.large'
    ML_M5_XLARGE = 'ml.m5.xlarge'
    ML_C5_LARGE = 'ml.c5.large'
    ML_C5_XLARGE = 'ml.c5.xlarge'
    ML_P3_2XLARGE = 'ml.p3.2xlarge'
    ML_P3_8XLARGE = 'ml.p3.8xlarge'
558
559
# Training job configuration
560
class TrainingJobConfig:
    """Training job configuration fields, declared as annotated class attributes."""

    # Fields declared without a default value
    training_job_name: str
    role_arn: str
    algorithm_specification: dict
    input_data_config: list
    output_data_config: dict
    resource_config: dict
    stopping_condition: dict

    # Fields with a default value (None, or False for the boolean switches)
    hyper_parameters: dict = None
    vpc_config: dict = None
    tags: list = None
    enable_network_isolation: bool = False
    enable_inter_container_traffic_encryption: bool = False
    enable_managed_spot_training: bool = False
    checkpoint_config: dict = None
    debug_hook_config: dict = None
    debug_rule_configurations: list = None
    tensor_board_output_config: dict = None
    experiment_config: dict = None
    profiler_config: dict = None
    profiler_rule_configurations: list = None
    environment: dict = None
    retry_strategy: dict = None
583
584
# Model configuration
585
class ModelConfig:
    """Model configuration fields, declared as annotated class attributes."""

    # Fields declared without a default value
    model_name: str
    execution_role_arn: str

    # Fields with a default value
    primary_container: dict = None
    containers: list = None
    inference_execution_config: dict = None
    tags: list = None
    vpc_config: dict = None
    enable_network_isolation: bool = False
594
595
# Endpoint configuration
596
class EndpointConfig:
    """Endpoint configuration fields, declared as annotated class attributes."""

    # Fields declared without a default value
    endpoint_name: str
    endpoint_config_name: str
    production_variants: list

    # Fields with a default value of None
    data_capture_config: dict = None
    tags: list = None
    kms_key_id: str = None
    async_inference_config: dict = None
    explainer_config: dict = None
    shadow_production_variants: list = None
606
```