Apache Airflow provider package that provides comprehensive AWS service integrations for orchestrating cloud workflows and data pipelines
—
Quality: Pending — Does it follow best practices?
Impact: Pending — No eval scenarios have been run.
Amazon SageMaker integration for end-to-end machine learning workflows including model training, tuning, deployment, and batch inference. Provides comprehensive MLOps capabilities for building, training, and deploying ML models at scale.
Core SageMaker client providing ML lifecycle management functionality.
class SageMakerHook(AwsBaseHook):
    """Core SageMaker client providing ML lifecycle management functionality.

    Exposes create/describe/delete operations for training jobs,
    hyperparameter tuning jobs, models, endpoint configurations, endpoints,
    batch transform jobs, and processing jobs.
    """

    def __init__(self, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Initialize SageMaker Hook.

        Parameters:
        - aws_conn_id: AWS connection ID
        """

    def create_training_job(self, config: dict, wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, max_ingestion_time: int | None = None) -> dict:
        """
        Create SageMaker training job.

        Parameters:
        - config: Training job configuration
        - wait_for_completion: Wait for job completion
        - print_log: Print training logs
        - check_interval: Status check interval in seconds
        - max_ingestion_time: Maximum log ingestion time (None means no limit)

        Returns:
        Training job details
        """

    def create_tuning_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Tuning job details
        """

    def create_model(self, config: dict) -> dict:
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration

        Returns:
        Model details
        """

    def create_endpoint_config(self, config: dict) -> dict:
        """
        Create endpoint configuration.

        Parameters:
        - config: Endpoint configuration

        Returns:
        Endpoint config details
        """

    def create_endpoint(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - wait_for_completion: Wait for endpoint to be in service
        - check_interval: Status check interval in seconds

        Returns:
        Endpoint details
        """

    def create_transform_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create batch transform job.

        Parameters:
        - config: Transform job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Transform job details
        """

    def create_processing_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create processing job.

        Parameters:
        - config: Processing job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Processing job details
        """

    def describe_training_job(self, name: str) -> dict:
        """
        Get training job details.

        Parameters:
        - name: Training job name

        Returns:
        Training job description
        """

    def describe_model(self, name: str) -> dict:
        """
        Get model details.

        Parameters:
        - name: Model name

        Returns:
        Model description
        """

    def describe_endpoint(self, name: str) -> dict:
        """
        Get endpoint details.

        Parameters:
        - name: Endpoint name

        Returns:
        Endpoint description
        """

    def delete_model(self, name: str) -> None:
        """
        Delete SageMaker model.

        Parameters:
        - name: Model name
        """

    def delete_endpoint_config(self, name: str) -> None:
        """
        Delete endpoint configuration.

        Parameters:
        - name: Endpoint config name
        """

    def delete_endpoint(self, name: str) -> None:
        """
        Delete SageMaker endpoint.

        Parameters:
        - name: Endpoint name
        """


# Task implementations for SageMaker ML operations.
class SageMakerTrainingOperator(BaseOperator):
    """Airflow task that starts a SageMaker training job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, max_ingestion_time: int | None = None, **kwargs):
        """
        Start SageMaker training job.

        Parameters:
        - config: Training job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - print_log: Print training logs
        - check_interval: Status check interval
        - max_ingestion_time: Maximum log ingestion time (None means no limit)
        """
class SageMakerTuningOperator(BaseOperator):
    """Airflow task that starts a SageMaker hyperparameter tuning job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval
        """
class SageMakerModelOperator(BaseOperator):
    """Airflow task that creates a SageMaker model."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration
        - aws_conn_id: AWS connection ID
        """
class SageMakerEndpointOperator(BaseOperator):
    """Airflow task that creates (deploys) a SageMaker endpoint."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for endpoint creation
        - check_interval: Status check interval
        """
class SageMakerTransformOperator(BaseOperator):
    """Airflow task that starts a SageMaker batch transform job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start batch transform job.

        Parameters:
        - config: Transform job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval
        """
class SageMakerProcessingOperator(BaseOperator):
    """Airflow task that starts a SageMaker processing job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start processing job.

        Parameters:
        - config: Processing job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval
        """
class SageMakerDeleteModelOperator(BaseOperator):
    """Airflow task that deletes a SageMaker model."""

    def __init__(self, model_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Delete SageMaker model.

        Parameters:
        - model_name: Model name to delete
        - aws_conn_id: AWS connection ID
        """


# Monitoring tasks for SageMaker job and endpoint states.
class SageMakerTrainingSensor(BaseSensorOperator):
    """Sensor that waits for a SageMaker training job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker training job completion.

        Parameters:
        - job_name: Training job name
        - aws_conn_id: AWS connection ID
        """
class SageMakerTuningSensor(BaseSensorOperator):
    """Sensor that waits for a hyperparameter tuning job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for hyperparameter tuning job completion.

        Parameters:
        - job_name: Tuning job name
        - aws_conn_id: AWS connection ID
        """
class SageMakerTransformSensor(BaseSensorOperator):
    """Sensor that waits for a batch transform job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for batch transform job completion.

        Parameters:
        - job_name: Transform job name
        - aws_conn_id: AWS connection ID
        """
class SageMakerEndpointSensor(BaseSensorOperator):
    """Sensor that waits for a SageMaker endpoint to be in service."""

    def __init__(self, endpoint_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker endpoint to be in service.

        Parameters:
        - endpoint_name: Endpoint name
        - aws_conn_id: AWS connection ID
        """


# Asynchronous triggers for SageMaker operations.
class SageMakerTrigger(BaseTrigger):
    """Deferrable trigger that polls a SageMaker job until it finishes."""

    def __init__(self, job_name: str, job_type: str, aws_conn_id: str = 'aws_default', poll_interval: int = 30, **kwargs):
        """
        Asynchronous trigger for SageMaker job monitoring.

        Parameters:
        - job_name: Job name to monitor
        - job_type: Type of job ('training', 'tuning', 'transform', 'processing')
        - aws_conn_id: AWS connection ID
        - poll_interval: Polling interval in seconds
        """


from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerTrainingOperator,
    SageMakerModelOperator,
    SageMakerEndpointOperator,
)
from datetime import datetime  # required by start_date below; was missing

dag = DAG('ml_pipeline', start_date=datetime(2023, 1, 1))

# Training job configuration
training_config = {
    'TrainingJobName': 'customer-churn-model-{{ ds }}',
    'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'AlgorithmSpecification': {
        'TrainingImage': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'TrainingInputMode': 'File'
    },
    'InputDataConfig': [
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/train/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        },
        {
            'ChannelName': 'validation',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/validation/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        }
    ],
    'OutputDataConfig': {
        'S3OutputPath': 's3://ml-model-artifacts/customer-churn/'
    },
    'ResourceConfig': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1,
        'VolumeSizeInGB': 30
    },
    'StoppingCondition': {
        'MaxRuntimeInSeconds': 3600
    },
    'HyperParameters': {
        'max_depth': '5',
        'eta': '0.2',
        'gamma': '4',
        'min_child_weight': '6',
        'subsample': '0.8',
        'silent': '0',
        'objective': 'binary:logistic',
        'num_round': '100'
    }
}

# Train model
train_model = SageMakerTrainingOperator(
    task_id='train_churn_model',
    config=training_config,
    wait_for_completion=True,
    print_log=True,
    dag=dag
)

# Create model
model_config = {
    'ModelName': 'customer-churn-model-{{ ds }}',
    'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'PrimaryContainer': {
        'Image': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'ModelDataUrl': 's3://ml-model-artifacts/customer-churn/customer-churn-model-{{ ds }}/output/model.tar.gz',
        'Environment': {
            'SAGEMAKER_PROGRAM': 'inference.py',
            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/code'
        }
    }
}

create_model = SageMakerModelOperator(
    task_id='create_model',
    config=model_config,
    dag=dag
)

# Deploy endpoint
endpoint_config = {
    'EndpointName': 'customer-churn-endpoint',
    'EndpointConfigName': 'customer-churn-config-{{ ds }}',
    'ProductionVariants': [
        {
            'VariantName': 'primary',
            'ModelName': 'customer-churn-model-{{ ds }}',
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.t2.medium',
            'InitialVariantWeight': 1
        }
    ]
}

deploy_endpoint = SageMakerEndpointOperator(
    task_id='deploy_endpoint',
    config=endpoint_config,
    wait_for_completion=True,
    dag=dag
)

train_model >> create_model >> deploy_endpoint

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
# Hyperparameter tuning configuration
tuning_config = {
    'HyperParameterTuningJobName': 'xgboost-tuning-{{ ds }}',
    'HyperParameterTuningJobConfig': {
        'Strategy': 'Bayesian',
        'HyperParameterTuningJobObjective': {
            'Type': 'Maximize',
            'MetricName': 'validation:auc'
        },
        'ResourceLimits': {
            'MaxNumberOfTrainingJobs': 20,
            'MaxParallelTrainingJobs': 3
        },
        'ParameterRanges': {
            'IntegerParameterRanges': [
                {
                    'Name': 'max_depth',
                    'MinValue': '1',
                    'MaxValue': '10'
                },
                {
                    'Name': 'num_round',
                    'MinValue': '50',
                    'MaxValue': '200'
                }
            ],
            'ContinuousParameterRanges': [
                {
                    'Name': 'eta',
                    'MinValue': '0.1',
                    'MaxValue': '0.5'
                },
                {
                    'Name': 'subsample',
                    'MinValue': '0.5',
                    'MaxValue': '1.0'
                }
            ]
        }
    },
    'TrainingJobDefinition': training_config
}

tune_hyperparameters = SageMakerTuningOperator(
    task_id='tune_hyperparameters',
    config=tuning_config,
    wait_for_completion=True,
    dag=dag
)

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
# Batch transform configuration
transform_config = {
    'TransformJobName': 'batch-inference-{{ ds }}',
    'ModelName': 'customer-churn-model-{{ ds }}',
    'TransformInput': {
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://ml-inference-data/batch/{{ ds }}/'
            }
        },
        'ContentType': 'text/csv',
        'SplitType': 'Line'
    },
    'TransformOutput': {
        'S3OutputPath': 's3://ml-inference-results/{{ ds }}/',
        'Accept': 'text/csv'
    },
    'TransformResources': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1
    }
}

batch_inference = SageMakerTransformOperator(
    task_id='batch_inference',
    config=transform_config,
    wait_for_completion=True,
    dag=dag
)

# SageMaker job states
class SageMakerJobState:
    """String constants for the states a SageMaker job reports."""

    IN_PROGRESS = 'InProgress'
    COMPLETED = 'Completed'
    FAILED = 'Failed'
    STOPPING = 'Stopping'
    STOPPED = 'Stopped'
# Instance types
class SageMakerInstanceType:
    """String constants for commonly used SageMaker instance types."""

    ML_T2_MEDIUM = 'ml.t2.medium'
    ML_T2_LARGE = 'ml.t2.large'
    ML_M5_LARGE = 'ml.m5.large'
    ML_M5_XLARGE = 'ml.m5.xlarge'
    ML_C5_LARGE = 'ml.c5.large'
    ML_C5_XLARGE = 'ml.c5.xlarge'
    ML_P3_2XLARGE = 'ml.p3.2xlarge'
    ML_P3_8XLARGE = 'ml.p3.8xlarge'
# Training job configuration
class TrainingJobConfig:
training_job_name: str
role_arn: str
algorithm_specification: dict
input_data_config: list
output_data_config: dict
resource_config: dict
stopping_condition: dict
hyper_parameters: dict = None
vpc_config: dict = None
tags: list = None
enable_network_isolation: bool = False
enable_inter_container_traffic_encryption: bool = False
enable_managed_spot_training: bool = False
checkpoint_config: dict = None
debug_hook_config: dict = None
debug_rule_configurations: list = None
tensor_board_output_config: dict = None
experiment_config: dict = None
profiler_config: dict = None
profiler_rule_configurations: list = None
environment: dict = None
retry_strategy: dict = None
# Model configuration
class ModelConfig:
model_name: str
execution_role_arn: str
primary_container: dict = None
containers: list = None
inference_execution_config: dict = None
tags: list = None
vpc_config: dict = None
enable_network_isolation: bool = False
# Endpoint configuration
class EndpointConfig:
endpoint_name: str
endpoint_config_name: str
production_variants: list
data_capture_config: dict = None
tags: list = None
kms_key_id: str = None
async_inference_config: dict = None
explainer_config: dict = None
shadow_production_variants: list = NoneInstall with Tessl CLI
# npx tessl i tessl/pypi-apache-airflow-providers-amazon