or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

athena-analytics.md, authentication.md, batch-processing.md, data-transfers.md, dms-migration.md, dynamodb-nosql.md, ecs-containers.md, eks-kubernetes.md, emr-clusters.md, glue-processing.md, index.md, lambda-functions.md, messaging-sns-sqs.md, rds-databases.md, redshift-warehouse.md, s3-storage.md, sagemaker-ml.md

docs/sagemaker-ml.md

0

# SageMaker Machine Learning

1

2

Amazon SageMaker integration for end-to-end machine learning workflows: model training, hyperparameter tuning, model and endpoint deployment, processing jobs, and batch inference. This page documents the SageMaker hook, operators, sensors, and trigger, followed by usage examples and configuration types.

3

4

## Capabilities

5

6

### SageMaker Hook

7

8

Hook built on `AwsBaseHook` that creates SageMaker training, tuning, transform, and processing jobs; creates models, endpoint configurations, and endpoints; describes training jobs, models, and endpoints; and deletes models, endpoint configurations, and endpoints.

9

10

```python { .api }

11

class SageMakerHook(AwsBaseHook):
    """
    Hook for the Amazon SageMaker API.

    Creates training, tuning, transform, and processing jobs; creates models,
    endpoint configurations, and endpoints; describes training jobs, models,
    and endpoints; and deletes models, endpoint configurations, and endpoints.
    """

    def __init__(self, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Initialize SageMaker Hook.

        Parameters:
        - aws_conn_id: AWS connection ID
        """

    def create_training_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: int | None = None,
    ) -> dict:
        """
        Create SageMaker training job.

        Parameters:
        - config: Training job configuration
        - wait_for_completion: Wait for job completion
        - print_log: Print training logs
        - check_interval: Status check interval in seconds
        - max_ingestion_time: Maximum log ingestion time in seconds; optional (default None)

        Returns:
        Training job details
        """

    def create_tuning_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Tuning job details
        """

    def create_model(self, config: dict) -> dict:
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration

        Returns:
        Model details
        """

    def create_endpoint_config(self, config: dict) -> dict:
        """
        Create endpoint configuration.

        Parameters:
        - config: Endpoint configuration

        Returns:
        Endpoint config details
        """

    def create_endpoint(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - wait_for_completion: Wait for endpoint to be in service
        - check_interval: Status check interval in seconds

        Returns:
        Endpoint details
        """

    def create_transform_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create batch transform job.

        Parameters:
        - config: Transform job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Transform job details
        """

    def create_processing_job(
        self,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
    ) -> dict:
        """
        Create processing job.

        Parameters:
        - config: Processing job configuration
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds

        Returns:
        Processing job details
        """

    def describe_training_job(self, name: str) -> dict:
        """
        Get training job details.

        Parameters:
        - name: Training job name

        Returns:
        Training job description
        """

    def describe_model(self, name: str) -> dict:
        """
        Get model details.

        Parameters:
        - name: Model name

        Returns:
        Model description
        """

    def describe_endpoint(self, name: str) -> dict:
        """
        Get endpoint details.

        Parameters:
        - name: Endpoint name

        Returns:
        Endpoint description
        """

    def delete_model(self, name: str) -> None:
        """
        Delete SageMaker model.

        Parameters:
        - name: Model name
        """

    def delete_endpoint_config(self, name: str) -> None:
        """
        Delete endpoint configuration.

        Parameters:
        - name: Endpoint config name
        """

    def delete_endpoint(self, name: str) -> None:
        """
        Delete SageMaker endpoint.

        Parameters:
        - name: Endpoint name
        """

165

```

166

167

### SageMaker Operators

168

169

Airflow operators that run SageMaker training, tuning, transform, and processing jobs, create models and endpoints, and delete models.

170

171

```python { .api }

172

class SageMakerTrainingOperator(BaseOperator):
    """Operator that starts a SageMaker training job."""

    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: int | None = None,
        **kwargs,
    ):
        """
        Start SageMaker training job.

        Parameters:
        - config: Training job configuration (TrainingJobName, RoleArn, ... as in the usage example)
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - print_log: Print training logs
        - check_interval: Status check interval in seconds
        - max_ingestion_time: Maximum log ingestion time in seconds; optional (default None)
        """

185

186

class SageMakerTuningOperator(BaseOperator):
    """Operator that starts a SageMaker hyperparameter tuning job."""

    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start hyperparameter tuning job.

        Parameters:
        - config: Tuning job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds
        """

197

198

class SageMakerModelOperator(BaseOperator):
    """Operator that creates a SageMaker model."""

    def __init__(self, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Create SageMaker model.

        Parameters:
        - config: Model configuration
        - aws_conn_id: AWS connection ID
        """

207

208

class SageMakerEndpointOperator(BaseOperator):
    """Operator that creates a SageMaker endpoint."""

    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Create SageMaker endpoint.

        Parameters:
        - config: Endpoint configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for endpoint creation
        - check_interval: Status check interval in seconds
        """

219

220

class SageMakerTransformOperator(BaseOperator):
    """Operator that starts a SageMaker batch transform job."""

    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start batch transform job.

        Parameters:
        - config: Transform job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds
        """

231

232

class SageMakerProcessingOperator(BaseOperator):
    """Operator that starts a SageMaker processing job."""

    def __init__(
        self,
        config: dict,
        aws_conn_id: str = 'aws_default',
        wait_for_completion: bool = True,
        check_interval: int = 30,
        **kwargs,
    ):
        """
        Start processing job.

        Parameters:
        - config: Processing job configuration
        - aws_conn_id: AWS connection ID
        - wait_for_completion: Wait for job completion
        - check_interval: Status check interval in seconds
        """

243

244

class SageMakerDeleteModelOperator(BaseOperator):
    """Operator that deletes a SageMaker model by name."""

    def __init__(self, model_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Delete SageMaker model.

        Parameters:
        - model_name: Model name to delete
        - aws_conn_id: AWS connection ID
        """

253

```

254

255

### SageMaker Sensors

256

257

Sensors that wait for a SageMaker training, tuning, or transform job to finish, or for an endpoint to be in service.

258

259

```python { .api }

260

class SageMakerTrainingSensor(BaseSensorOperator):
    """Sensor that waits for a SageMaker training job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker training job completion.

        Parameters:
        - job_name: Training job name
        - aws_conn_id: AWS connection ID
        """

269

270

class SageMakerTuningSensor(BaseSensorOperator):
    """Sensor that waits for a hyperparameter tuning job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for hyperparameter tuning job completion.

        Parameters:
        - job_name: Tuning job name
        - aws_conn_id: AWS connection ID
        """

279

280

class SageMakerTransformSensor(BaseSensorOperator):
    """Sensor that waits for a batch transform job to complete."""

    def __init__(self, job_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for batch transform job completion.

        Parameters:
        - job_name: Transform job name
        - aws_conn_id: AWS connection ID
        """

289

290

class SageMakerEndpointSensor(BaseSensorOperator):
    """Sensor that waits for a SageMaker endpoint to be in service."""

    def __init__(self, endpoint_name: str, aws_conn_id: str = 'aws_default', **kwargs):
        """
        Wait for SageMaker endpoint to be in service.

        Parameters:
        - endpoint_name: Endpoint name
        - aws_conn_id: AWS connection ID
        """

299

```

300

301

### SageMaker Triggers

302

303

Trigger that polls a SageMaker job (training, tuning, transform, or processing) asynchronously until it reaches a final state.

304

305

```python { .api }

306

class SageMakerTrigger(BaseTrigger):
    """Asynchronous trigger that monitors a SageMaker job."""

    def __init__(
        self,
        job_name: str,
        job_type: str,
        aws_conn_id: str = 'aws_default',
        poll_interval: int = 30,
        **kwargs,
    ):
        """
        Asynchronous trigger for SageMaker job monitoring.

        Parameters:
        - job_name: Job name to monitor
        - job_type: Type of job ('training', 'tuning', 'transform', 'processing')
        - aws_conn_id: AWS connection ID
        - poll_interval: Polling interval in seconds
        """

317

```

318

319

## Usage Examples

320

321

### End-to-End ML Pipeline

322

323

```python

324

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerEndpointOperator,
    SageMakerModelOperator,
    SageMakerTrainingOperator,
)

# catchup=False stops the scheduler from backfilling one run (one training
# job plus one deployment) for every interval since start_date; catchup is
# on by default in Airflow 2.
dag = DAG('ml_pipeline', start_date=datetime(2023, 1, 1), catchup=False)

# Training job configuration (CreateTrainingJob request fields)
training_config = {
    'TrainingJobName': 'customer-churn-model-{{ ds }}',
    'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'AlgorithmSpecification': {
        'TrainingImage': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'TrainingInputMode': 'File',
    },
    'InputDataConfig': [
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/train/',
                    'S3DataDistributionType': 'FullyReplicated',
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None',
        },
        {
            'ChannelName': 'validation',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': 's3://ml-training-data/customer-churn/validation/',
                    'S3DataDistributionType': 'FullyReplicated',
                }
            },
            'ContentType': 'text/csv',
            'CompressionType': 'None',
        },
    ],
    'OutputDataConfig': {
        'S3OutputPath': 's3://ml-model-artifacts/customer-churn/',
    },
    'ResourceConfig': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1,
        'VolumeSizeInGB': 30,
    },
    'StoppingCondition': {
        'MaxRuntimeInSeconds': 3600,
    },
    'HyperParameters': {
        'max_depth': '5',
        'eta': '0.2',
        'gamma': '4',
        'min_child_weight': '6',
        'subsample': '0.8',
        'silent': '0',
        'objective': 'binary:logistic',
        'num_round': '100',
    },
}

# Train model
train_model = SageMakerTrainingOperator(
    task_id='train_churn_model',
    config=training_config,
    wait_for_completion=True,
    print_log=True,
    dag=dag,
)

# Create model from the training artifacts. The ModelDataUrl follows the
# <S3OutputPath>/<TrainingJobName>/output/model.tar.gz layout of the job above.
model_config = {
    'ModelName': 'customer-churn-model-{{ ds }}',
    'ExecutionRoleArn': 'arn:aws:iam::123456789012:role/SageMakerExecutionRole',
    'PrimaryContainer': {
        'Image': '382416733822.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest',
        'ModelDataUrl': 's3://ml-model-artifacts/customer-churn/customer-churn-model-{{ ds }}/output/model.tar.gz',
        'Environment': {
            'SAGEMAKER_PROGRAM': 'inference.py',
            'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/code',
        },
    },
}

create_model = SageMakerModelOperator(
    task_id='create_model',
    config=model_config,
    dag=dag,
)

# Deploy endpoint. ProductionVariants belongs to the CreateEndpointConfig
# request, not CreateEndpoint, so the endpoint config is described in its own
# 'EndpointConfig' section (created first) and the endpoint in 'Endpoint'.
endpoint_config = {
    'EndpointConfig': {
        'EndpointConfigName': 'customer-churn-config-{{ ds }}',
        'ProductionVariants': [
            {
                'VariantName': 'primary',
                'ModelName': 'customer-churn-model-{{ ds }}',
                'InitialInstanceCount': 1,
                'InstanceType': 'ml.t2.medium',
                'InitialVariantWeight': 1,
            }
        ],
    },
    'Endpoint': {
        'EndpointName': 'customer-churn-endpoint',
        'EndpointConfigName': 'customer-churn-config-{{ ds }}',
    },
}

deploy_endpoint = SageMakerEndpointOperator(
    task_id='deploy_endpoint',
    config=endpoint_config,
    wait_for_completion=True,
    dag=dag,
)

train_model >> create_model >> deploy_endpoint

442

```

443

444

### Hyperparameter Tuning

445

446

```python

447

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator

# Reuses `dag` and `training_config` from the end-to-end pipeline example.

# Hyperparameters searched by the tuning job
parameter_ranges = {
    'IntegerParameterRanges': [
        {'Name': 'max_depth', 'MinValue': '1', 'MaxValue': '10'},
        {'Name': 'num_round', 'MinValue': '50', 'MaxValue': '200'},
    ],
    'ContinuousParameterRanges': [
        {'Name': 'eta', 'MinValue': '0.1', 'MaxValue': '0.5'},
        {'Name': 'subsample', 'MinValue': '0.5', 'MaxValue': '1.0'},
    ],
}
tuned_names = {
    param['Name'] for ranges in parameter_ranges.values() for param in ranges
}

# TrainingJobDefinition is not a CreateTrainingJob request: it has no
# TrainingJobName (the tuner names each child job itself) and takes the fixed
# hyperparameters as StaticHyperParameters. The tuned ones are left out of the
# static set so their values come only from parameter_ranges.
training_job_definition = {
    key: value
    for key, value in training_config.items()
    if key not in ('TrainingJobName', 'HyperParameters')
}
training_job_definition['StaticHyperParameters'] = {
    name: value
    for name, value in training_config['HyperParameters'].items()
    if name not in tuned_names
}
# XGBoost must compute AUC for the 'validation:auc' objective metric to be reported.
training_job_definition['StaticHyperParameters']['eval_metric'] = 'auc'

# Hyperparameter tuning configuration
tuning_config = {
    'HyperParameterTuningJobName': 'xgboost-tuning-{{ ds }}',
    'HyperParameterTuningJobConfig': {
        'Strategy': 'Bayesian',
        'HyperParameterTuningJobObjective': {
            'Type': 'Maximize',
            'MetricName': 'validation:auc',
        },
        'ResourceLimits': {
            'MaxNumberOfTrainingJobs': 20,
            'MaxParallelTrainingJobs': 3,
        },
        'ParameterRanges': parameter_ranges,
    },
    'TrainingJobDefinition': training_job_definition,
}

tune_hyperparameters = SageMakerTuningOperator(
    task_id='tune_hyperparameters',
    config=tuning_config,
    wait_for_completion=True,
    dag=dag,
)

498

```

499

500

### Batch Inference

501

502

```python

503

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

# Reuses `dag` from the end-to-end pipeline example and scores with the model
# created there ('customer-churn-model-{{ ds }}').

# Batch transform configuration
transform_config = {
    'TransformJobName': 'batch-inference-{{ ds }}',
    'ModelName': 'customer-churn-model-{{ ds }}',
    'TransformInput': {
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://ml-inference-data/batch/{{ ds }}/',
            }
        },
        'ContentType': 'text/csv',
        'SplitType': 'Line',
    },
    'TransformOutput': {
        'S3OutputPath': 's3://ml-inference-results/{{ ds }}/',
        'Accept': 'text/csv',
        # Newline after every record, so the output is one prediction per line,
        # mirroring SplitType='Line' on the input.
        'AssembleWith': 'Line',
    },
    'TransformResources': {
        'InstanceType': 'ml.m5.large',
        'InstanceCount': 1,
    },
}

batch_inference = SageMakerTransformOperator(
    task_id='batch_inference',
    config=transform_config,
    wait_for_completion=True,
    dag=dag,
)

535

```

536

537

## Types

538

539

```python { .api }

540

# SageMaker job states
class SageMakerJobState:
    """String status values reported for SageMaker jobs."""

    IN_PROGRESS = 'InProgress'
    COMPLETED = 'Completed'
    FAILED = 'Failed'
    STOPPING = 'Stopping'
    STOPPED = 'Stopped'

547

548

# Instance types
class SageMakerInstanceType:
    """Common SageMaker ML instance type identifiers."""

    ML_T2_MEDIUM = 'ml.t2.medium'
    ML_T2_LARGE = 'ml.t2.large'
    ML_M5_LARGE = 'ml.m5.large'
    ML_M5_XLARGE = 'ml.m5.xlarge'
    ML_C5_LARGE = 'ml.c5.large'
    ML_C5_XLARGE = 'ml.c5.xlarge'
    ML_P3_2XLARGE = 'ml.p3.2xlarge'
    ML_P3_8XLARGE = 'ml.p3.8xlarge'

558

559

# Training job configuration

560

class TrainingJobConfig:

561

training_job_name: str

562

role_arn: str

563

algorithm_specification: dict

564

input_data_config: list

565

output_data_config: dict

566

resource_config: dict

567

stopping_condition: dict

568

hyper_parameters: dict = None

569

vpc_config: dict = None

570

tags: list = None

571

enable_network_isolation: bool = False

572

enable_inter_container_traffic_encryption: bool = False

573

enable_managed_spot_training: bool = False

574

checkpoint_config: dict = None

575

debug_hook_config: dict = None

576

debug_rule_configurations: list = None

577

tensor_board_output_config: dict = None

578

experiment_config: dict = None

579

profiler_config: dict = None

580

profiler_rule_configurations: list = None

581

environment: dict = None

582

retry_strategy: dict = None

583

584

# Model configuration

585

class ModelConfig:

586

model_name: str

587

execution_role_arn: str

588

primary_container: dict = None

589

containers: list = None

590

inference_execution_config: dict = None

591

tags: list = None

592

vpc_config: dict = None

593

enable_network_isolation: bool = False

594

595

# Endpoint configuration

596

class EndpointConfig:

597

endpoint_name: str

598

endpoint_config_name: str

599

production_variants: list

600

data_capture_config: dict = None

601

tags: list = None

602

kms_key_id: str = None

603

async_inference_config: dict = None

604

explainer_config: dict = None

605

shadow_production_variants: list = None

606

```