CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-apache-airflow-providers-amazon

Apache Airflow provider package that provides comprehensive AWS service integrations for orchestrating cloud workflows and data pipelines

Pending

Quality

Pending

Does it follow best practices?

Impact

Pending

No eval scenarios have been run

Overview
Eval results
Files

docs/sagemaker-ml.md

SageMaker Machine Learning

Amazon SageMaker integration for end-to-end machine learning workflows including model training, tuning, deployment, and batch inference. Provides comprehensive MLOps capabilities for building, training, and deploying ML models at scale.

Capabilities

SageMaker Hook

Core SageMaker client providing ML lifecycle management functionality.

class SageMakerHook(AwsBaseHook):
    """Interact with Amazon SageMaker.

    Wraps the SageMaker client to cover the ML lifecycle: training,
    hyperparameter tuning, model creation, endpoint deployment, batch
    transform, processing, and resource deletion.
    """

    def __init__(self, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Initialize SageMaker Hook.
        
        Parameters:
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - **kwargs: extra arguments forwarded to AwsBaseHook
        """

    def create_training_job(self, config: dict, wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, max_ingestion_time: int | None = None) -> dict:
        """
        Create SageMaker training job.
        
        Parameters:
        - config: Training job configuration (CreateTrainingJob request shape;
          see the Usage Examples section for the expected keys)
        - wait_for_completion: Block until the job reaches a terminal state
        - print_log: Print training logs while waiting
        - check_interval: Status check interval in seconds
        - max_ingestion_time: Maximum log ingestion time in seconds;
          None presumably means no limit — confirm against implementation
        
        Returns:
        Training job details
        """

    def create_tuning_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create hyperparameter tuning job.
        
        Parameters:
        - config: Tuning job configuration (CreateHyperParameterTuningJob shape)
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        
        Returns:
        Tuning job details
        """

    def create_model(self, config: dict) -> dict:
        """
        Create SageMaker model.
        
        Parameters:
        - config: Model configuration (CreateModel request shape)
        
        Returns:
        Model details
        """

    def create_endpoint_config(self, config: dict) -> dict:
        """
        Create endpoint configuration.
        
        Parameters:
        - config: Endpoint configuration (CreateEndpointConfig request shape)
        
        Returns:
        Endpoint config details
        """

    def create_endpoint(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create SageMaker endpoint.
        
        Parameters:
        - config: Endpoint configuration
        - wait_for_completion: Block until the endpoint is in service
        - check_interval: Status check interval in seconds
        
        Returns:
        Endpoint details
        """

    def create_transform_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create batch transform job.
        
        Parameters:
        - config: Transform job configuration (CreateTransformJob shape)
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        
        Returns:
        Transform job details
        """

    def create_processing_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30) -> dict:
        """
        Create processing job.
        
        Parameters:
        - config: Processing job configuration (CreateProcessingJob shape)
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        
        Returns:
        Processing job details
        """

    def describe_training_job(self, name: str) -> dict:
        """
        Get training job details.
        
        Parameters:
        - name: Training job name
        
        Returns:
        Training job description
        """

    def describe_model(self, name: str) -> dict:
        """
        Get model details.
        
        Parameters:
        - name: Model name
        
        Returns:
        Model description
        """

    def describe_endpoint(self, name: str) -> dict:
        """
        Get endpoint details.
        
        Parameters:
        - name: Endpoint name
        
        Returns:
        Endpoint description
        """

    def delete_model(self, name: str) -> None:
        """
        Delete SageMaker model.
        
        Parameters:
        - name: Model name
        """

    def delete_endpoint_config(self, name: str) -> None:
        """
        Delete endpoint configuration.
        
        Parameters:
        - name: Endpoint config name
        """

    def delete_endpoint(self, name: str) -> None:
        """
        Delete SageMaker endpoint.
        
        Parameters:
        - name: Endpoint name
        """

SageMaker Operators

Task implementations for SageMaker ML operations.

class SageMakerTrainingOperator(BaseOperator):
    """Airflow operator that starts a SageMaker training job and optionally waits for it."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, max_ingestion_time: int | None = None, **kwargs):
        """
        Start SageMaker training job.
        
        Parameters:
        - config: Training job configuration (CreateTrainingJob request shape)
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - wait_for_completion: Block until the job reaches a terminal state
        - print_log: Print training logs while waiting
        - check_interval: Status check interval in seconds
        - max_ingestion_time: Maximum log ingestion time in seconds;
          None presumably means no limit — confirm against implementation
        """

class SageMakerTuningOperator(BaseOperator):
    """Airflow operator that starts a SageMaker hyperparameter tuning job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start hyperparameter tuning job.
        
        Parameters:
        - config: Tuning job configuration (CreateHyperParameterTuningJob shape)
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        """

class SageMakerModelOperator(BaseOperator):
    """Airflow operator that creates a SageMaker model from trained artifacts."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Create SageMaker model.
        
        Parameters:
        - config: Model configuration (CreateModel request shape)
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

class SageMakerEndpointOperator(BaseOperator):
    """Airflow operator that deploys a SageMaker model to a real-time endpoint."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Create SageMaker endpoint.
        
        Parameters:
        - config: Endpoint configuration
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - wait_for_completion: Block until the endpoint is in service
        - check_interval: Status check interval in seconds
        """

class SageMakerTransformOperator(BaseOperator):
    """Airflow operator that runs a SageMaker batch transform (batch inference) job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start batch transform job.
        
        Parameters:
        - config: Transform job configuration (CreateTransformJob shape)
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        """

class SageMakerProcessingOperator(BaseOperator):
    """Airflow operator that runs a SageMaker processing job."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', wait_for_completion: bool = True, check_interval: int = 30, **kwargs):
        """
        Start processing job.
        
        Parameters:
        - config: Processing job configuration (CreateProcessingJob shape)
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - wait_for_completion: Block until the job reaches a terminal state
        - check_interval: Status check interval in seconds
        """

class SageMakerDeleteModelOperator(BaseOperator):
    """Airflow operator that deletes a SageMaker model (cleanup task)."""

    def __init__(self, model_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Delete SageMaker model.
        
        Parameters:
        - model_name: Name of the model to delete
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

SageMaker Sensors

Monitoring tasks for SageMaker job and endpoint states.

class SageMakerTrainingSensor(BaseSensorOperator):
    """Sensor that waits until a SageMaker training job reaches a terminal state."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker training job completion.
        
        Parameters:
        - job_name: Name of the training job to monitor
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

class SageMakerTuningSensor(BaseSensorOperator):
    """Sensor that waits until a hyperparameter tuning job reaches a terminal state."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for hyperparameter tuning job completion.
        
        Parameters:
        - job_name: Name of the tuning job to monitor
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

class SageMakerTransformSensor(BaseSensorOperator):
    """Sensor that waits until a batch transform job reaches a terminal state."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for batch transform job completion.
        
        Parameters:
        - job_name: Name of the transform job to monitor
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

class SageMakerEndpointSensor(BaseSensorOperator):
    """Sensor that waits until a SageMaker endpoint is in service."""

    def __init__(self, endpoint_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker endpoint to be in service.
        
        Parameters:
        - endpoint_name: Name of the endpoint to monitor
        - aws_conn_id: Airflow connection ID holding AWS credentials
        """

SageMaker Triggers

Asynchronous triggers for SageMaker operations.

class SageMakerTrigger(BaseTrigger):
    """Deferrable trigger that polls a SageMaker job asynchronously (frees the worker slot while waiting)."""

    def __init__(self, job_name: str, job_type: str, aws_conn_id: str = 'aws_default', poll_interval: int = 30, **kwargs):
        """
        Asynchronous trigger for SageMaker job monitoring.
        
        Parameters:
        - job_name: Name of the job to monitor
        - job_type: Type of job ('training', 'tuning', 'transform', 'processing')
        - aws_conn_id: Airflow connection ID holding AWS credentials
        - poll_interval: Polling interval in seconds
        """

Usage Examples

End-to-End ML Pipeline

# End-to-end ML pipeline: train -> register model -> deploy endpoint.
from datetime import datetime  # required for start_date below; missing in the original example

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerTrainingOperator,
    SageMakerModelOperator,
    SageMakerEndpointOperator
)

dag = DAG('ml_pipeline', start_date=datetime(2023, 1, 1))

# Training job configuration (CreateTrainingJob request shape).
# '{{ ds }}' is an Airflow template rendered to the logical date, which keeps
# job/model names unique per DAG run.
training_config = {
    'TrainingJobName': 'customer-churn-model-{{ ds }}',
    'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'AlgorithmSpecification': {
        'TrainingImage': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'TrainingInputMode': 'File'
    },
    'InputDataConfig': [
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/train/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        },
        {
            'ChannelName': 'validation',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/validation/',
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None'
        }
    ],
    'OutputDataConfig': {
        'S3OutputPath': 's3://ml-model-artifacts/customer-churn/'
    },
    'ResourceConfig': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1,
        'VolumeSizeInGB': 30
    },
    'StoppingCondition': {
        'MaxRuntimeInSeconds': 3600
    },
    # SageMaker requires hyperparameter values to be strings.
    'HyperParameters': {
        'max_depth': '5',
        'eta': '0.2',
        'gamma': '4',
        'min_child_weight': '6',
        'subsample': '0.8',
        'silent': '0',
        'objective': 'binary:logistic',
        'num_round': '100'
    }
}

# Train model
train_model = SageMakerTrainingOperator(
    task_id='train_churn_model',
    config=training_config,
    wait_for_completion=True,
    print_log=True,
    dag=dag
)

# Create model from the artifacts written by the training job above.
model_config = {
    'ModelName': 'customer-churn-model-{{ ds }}',
    'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'PrimaryContainer': {
        'Image': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        # Path matches OutputDataConfig.S3OutputPath + TrainingJobName above.
        'ModelDataUrl': 's3://ml-model-artifacts/customer-churn/customer-churn-model-{{ ds }}/output/model.tar.gz',
        'Environment': {
            'SAGEMAKER_PROGRAM': 'inference.py',
            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/code'
        }
    }
}

create_model = SageMakerModelOperator(
    task_id='create_model',
    config=model_config,
    dag=dag
)

# Deploy endpoint serving the model registered above.
endpoint_config = {
    'EndpointName': 'customer-churn-endpoint',
    'EndpointConfigName': 'customer-churn-config-{{ ds }}',
    'ProductionVariants': [
        {
            'VariantName': 'primary',
            'ModelName': 'customer-churn-model-{{ ds }}',
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.t2.medium',
            'InitialVariantWeight': 1
        }
    ]
}

deploy_endpoint = SageMakerEndpointOperator(
    task_id='deploy_endpoint',
    config=endpoint_config,
    wait_for_completion=True,
    dag=dag
)

train_model >> create_model >> deploy_endpoint

Hyperparameter Tuning

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator

# Hyperparameter tuning configuration (CreateHyperParameterTuningJob shape).
# NOTE: `training_config` and `dag` are defined in the "End-to-End ML Pipeline"
# example above — this snippet extends that pipeline.
tuning_config = {
    'HyperParameterTuningJobName': 'xgboost-tuning-{{ ds }}',
    'HyperParameterTuningJobConfig': {
        'Strategy': 'Bayesian',
        # Optimize for validation AUC across all trials.
        'HyperParameterTuningJobObjective': {
            'Type': 'Maximize',
            'MetricName': 'validation:auc'
        },
        # Caps total cost: at most 20 trials, 3 running concurrently.
        'ResourceLimits': {
            'MaxNumberOfTrainingJobs': 20,
            'MaxParallelTrainingJobs': 3
        },
        # Ranges are expressed as strings, per the SageMaker API.
        'ParameterRanges': {
            'IntegerParameterRanges': [
                {
                    'Name': 'max_depth',
                    'MinValue': '1',
                    'MaxValue': '10'
                },
                {
                    'Name': 'num_round',
                    'MinValue': '50',
                    'MaxValue': '200'
                }
            ],
            'ContinuousParameterRanges': [
                {
                    'Name': 'eta',
                    'MinValue': '0.1',
                    'MaxValue': '0.5'
                },
                {
                    'Name': 'subsample',
                    'MinValue': '0.5',
                    'MaxValue': '1.0'
                }
            ]
        }
    },
    # Base job definition reused from the training example above.
    'TrainingJobDefinition': training_config
}

tune_hyperparameters = SageMakerTuningOperator(
    task_id='tune_hyperparameters',
    config=tuning_config,
    wait_for_completion=True,
    dag=dag
)

Batch Inference

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

# Batch transform configuration (CreateTransformJob shape).
# NOTE: `dag` and the referenced model come from the "End-to-End ML Pipeline"
# example above; '{{ ds }}' partitions input/output by logical date.
transform_config = {
    'TransformJobName': 'batch-inference-{{ ds }}',
    'ModelName': 'customer-churn-model-{{ ds }}',
    'TransformInput': {
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://ml-inference-data/batch/{{ ds }}/'
            }
        },
        'ContentType': 'text/csv',
        # One inference record per line of the CSV input.
        'SplitType': 'Line'
    },
    'TransformOutput': {
        'S3OutputPath': 's3://ml-inference-results/{{ ds }}/',
        'Accept': 'text/csv'
    },
    'TransformResources': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1
    }
}

batch_inference = SageMakerTransformOperator(
    task_id='batch_inference',
    config=transform_config,
    wait_for_completion=True,
    dag=dag
)

Types

# SageMaker job states
class SageMakerJobState:
    """Terminal and in-flight states reported by SageMaker for jobs."""
    IN_PROGRESS: str = 'InProgress'
    COMPLETED: str = 'Completed'
    FAILED: str = 'Failed'
    STOPPING: str = 'Stopping'
    STOPPED: str = 'Stopped'

# Instance types
class SageMakerInstanceType:
    """Common SageMaker instance type identifiers (t2/m5/c5 CPU, p3 GPU families)."""
    ML_T2_MEDIUM: str = 'ml.t2.medium'
    ML_T2_LARGE: str = 'ml.t2.large'
    ML_M5_LARGE: str = 'ml.m5.large'
    ML_M5_XLARGE: str = 'ml.m5.xlarge'
    ML_C5_LARGE: str = 'ml.c5.large'
    ML_C5_XLARGE: str = 'ml.c5.xlarge'
    ML_P3_2XLARGE: str = 'ml.p3.2xlarge'
    ML_P3_8XLARGE: str = 'ml.p3.8xlarge'

# Training job configuration
class TrainingJobConfig:
    """Typed shape of a training job `config` dict (snake_case mirror of CreateTrainingJob fields).

    Optional fields default to None; annotations corrected to `X | None`
    to match those defaults.
    """
    training_job_name: str
    role_arn: str
    algorithm_specification: dict
    input_data_config: list
    output_data_config: dict
    resource_config: dict
    stopping_condition: dict
    hyper_parameters: dict | None = None
    vpc_config: dict | None = None
    tags: list | None = None
    enable_network_isolation: bool = False
    enable_inter_container_traffic_encryption: bool = False
    enable_managed_spot_training: bool = False
    checkpoint_config: dict | None = None
    debug_hook_config: dict | None = None
    debug_rule_configurations: list | None = None
    tensor_board_output_config: dict | None = None
    experiment_config: dict | None = None
    profiler_config: dict | None = None
    profiler_rule_configurations: list | None = None
    environment: dict | None = None
    retry_strategy: dict | None = None

# Model configuration
class ModelConfig:
    """Typed shape of a model `config` dict (snake_case mirror of CreateModel fields).

    Exactly one of `primary_container` or `containers` is expected —
    presumably mutually exclusive; confirm against implementation.
    """
    model_name: str
    execution_role_arn: str
    primary_container: dict | None = None
    containers: list | None = None
    inference_execution_config: dict | None = None
    tags: list | None = None
    vpc_config: dict | None = None
    enable_network_isolation: bool = False

# Endpoint configuration
class EndpointConfig:
    """Typed shape of an endpoint `config` dict (endpoint + endpoint-config fields combined)."""
    endpoint_name: str
    endpoint_config_name: str
    production_variants: list
    data_capture_config: dict | None = None
    tags: list | None = None
    kms_key_id: str | None = None
    async_inference_config: dict | None = None
    explainer_config: dict | None = None
    shadow_production_variants: list | None = None

Install with Tessl CLI

npx tessl i tessl/pypi-apache-airflow-providers-amazon

docs

athena-analytics.md

authentication.md

batch-processing.md

data-transfers.md

dms-migration.md

dynamodb-nosql.md

ecs-containers.md

eks-kubernetes.md

emr-clusters.md

glue-processing.md

index.md

lambda-functions.md

messaging-sns-sqs.md

rds-databases.md

redshift-warehouse.md

s3-storage.md

sagemaker-ml.md

tile.json