0
# Core Training and Model Management
1
2
Fundamental classes and functions for training models and managing deployments in SageMaker. These core components provide the foundation for all ML workflows including training, deployment, and inference.
3
4
## Capabilities
5
6
### Estimator Base Classes
7
8
Core estimator classes that provide the foundation for all training workflows, handling AWS service integration, resource management, and deployment.
9
10
```python { .api }
11
class Estimator:
12
"""
13
Base class for all SageMaker estimators.
14
15
Parameters:
16
- image_uri (str): Docker image URI for training
17
- role (str): IAM role ARN with SageMaker permissions
18
- instance_count (int): Number of training instances
19
- instance_type (str): EC2 instance type for training
20
- output_path (str, optional): S3 path for model artifacts
21
- sagemaker_session (Session, optional): SageMaker session
22
- hyperparameters (dict, optional): Algorithm hyperparameters
23
- environment (dict, optional): Environment variables
24
- max_run (int, optional): Maximum training time in seconds
25
- input_mode (str, optional): Training input mode ('File', 'Pipe', or 'FastFile')
26
- vpc_config (dict, optional): VPC configuration
27
- metric_definitions (list, optional): Custom metric definitions
28
"""
29
def __init__(self, image_uri: str, role: str, instance_count: int,
30
instance_type: str, output_path: str = None,
31
sagemaker_session: Session = None, **kwargs): ...
32
33
def fit(self, inputs, wait: bool = True, logs: bool = True,
34
job_name: str = None, experiment_config: dict = None): ...
35
36
def deploy(self, initial_instance_count: int, instance_type: str,
37
serializer: BaseSerializer = None, deserializer: BaseDeserializer = None,
38
accelerator_type: str = None, endpoint_name: str = None,
39
inference_component_name: str = None, **kwargs) -> Predictor: ...
40
41
def create_model(self, vpc_config_override: dict = None, **kwargs) -> Model: ...
42
43
def delete_endpoint(self, endpoint_name: str = None): ...
44
45
class Framework(Estimator):
46
"""
47
Base class for framework-specific estimators.
48
49
Parameters:
50
- entry_point (str): Path to training script
51
- source_dir (str, optional): Directory containing training code
52
- dependencies (list, optional): List of additional dependencies
53
- code_location (str, optional): S3 location for training code
54
"""
55
def __init__(self, entry_point: str, source_dir: str = None,
56
dependencies: list = None, **kwargs): ...
57
```
58
59
### Model Management
60
61
Model classes for deploying trained models, managing model artifacts, and creating inference endpoints.
62
63
```python { .api }
64
class Model:
65
"""
66
Model class for deploying trained models to SageMaker endpoints.
67
68
Parameters:
69
- image_uri (str): Docker image URI for inference
70
- model_data (str): S3 path to model artifacts
71
- role (str): IAM role ARN with SageMaker permissions
72
- predictor_cls (type, optional): Predictor class for deployment
73
- env (dict, optional): Environment variables
74
- name (str, optional): Model name
75
- vpc_config (dict, optional): VPC configuration
76
- sagemaker_session (Session, optional): SageMaker session
77
"""
78
def __init__(self, image_uri: str, model_data: str, role: str,
79
predictor_cls: type = None, env: dict = None, **kwargs): ...
80
81
def deploy(self, initial_instance_count: int, instance_type: str,
82
serializer: BaseSerializer = None, deserializer: BaseDeserializer = None,
83
accelerator_type: str = None, endpoint_name: str = None,
84
inference_component_name: str = None, **kwargs) -> Predictor: ...
85
86
def create(self, instance_type: str = None, accelerator_type: str = None): ...
87
88
def delete_model(self): ...
89
90
def register(self, content_types: list, response_types: list,
91
inference_instances: list = None, transform_instances: list = None,
92
model_package_group_name: str = None, **kwargs): ...
93
94
class ModelPackage:
95
"""
96
Model package class for versioned model management and deployment.
97
98
Parameters:
99
- role (str): IAM role ARN
100
- model_data (str, optional): S3 path to model artifacts
101
- image_uri (str, optional): Docker image URI
102
- model_package_arn (str, optional): Existing model package ARN
103
"""
104
def __init__(self, role: str, model_data: str = None, image_uri: str = None,
105
model_package_arn: str = None, **kwargs): ...
106
107
def deploy(self, initial_instance_count: int, instance_type: str, **kwargs) -> Predictor: ...
108
109
class PipelineModel:
110
"""
111
Pipeline model for chaining multiple models in sequence.
112
113
Parameters:
114
- name (str): Pipeline model name
115
- role (str): IAM role ARN
116
- models (list): List of Model objects to chain
117
"""
118
def __init__(self, name: str, role: str, models: list, **kwargs): ...
119
120
def deploy(self, initial_instance_count: int, instance_type: str, **kwargs) -> Predictor: ...
121
```
122
123
### Prediction and Inference
124
125
Predictor classes for making real-time and asynchronous predictions against deployed models.
126
127
```python { .api }
128
class Predictor:
129
"""
130
Base predictor class for real-time inference.
131
132
Parameters:
133
- endpoint_name (str): SageMaker endpoint name
134
- sagemaker_session (Session, optional): SageMaker session
135
- serializer (BaseSerializer, optional): Input serializer
136
- deserializer (BaseDeserializer, optional): Output deserializer
137
"""
138
def __init__(self, endpoint_name: str, sagemaker_session: Session = None,
139
serializer: BaseSerializer = None, deserializer: BaseDeserializer = None): ...
140
141
def predict(self, data, initial_args: dict = None, target_model: str = None,
142
target_variant: str = None, inference_id: str = None): ...
143
144
def update_endpoint(self, initial_instance_count: int = None,
145
instance_type: str = None, **kwargs): ...
146
147
def delete_endpoint(self, delete_endpoint_config: bool = True): ...
148
149
def delete_model(self): ...
150
151
def enable_data_capture(self, sampling_percentage: int = 20,
152
capture_options: list = None): ...
153
154
def disable_data_capture(self): ...
155
156
class AsyncPredictor:
157
"""
158
Async predictor for asynchronous inference.
159
160
Parameters:
161
- predictor (Predictor): Base predictor instance
162
- name (str, optional): Async inference name
163
"""
164
def __init__(self, predictor: Predictor, name: str = None): ...
165
166
def predict_async(self, input_path: str, initial_args: dict = None,
167
inference_id: str = None) -> str: ...
168
169
def describe_async_inference_result(self, result_path: str) -> dict: ...
170
```
171
172
### Session Management
173
174
Session classes for managing AWS credentials, regions, and SageMaker service configurations.
175
176
```python { .api }
177
class Session:
178
"""
179
SageMaker session for managing service interactions.
180
181
Parameters:
182
- boto_session (boto3.Session, optional): Boto3 session
183
- sagemaker_client (boto3.Client, optional): SageMaker client
184
- sagemaker_runtime_client (boto3.Client, optional): SageMaker Runtime client
185
- default_bucket (str, optional): Default S3 bucket name
186
- s3_resource (boto3.Resource, optional): S3 resource
187
- settings (SessionSettings, optional): Session settings
188
"""
189
def __init__(self, boto_session: 'boto3.Session' = None,
190
sagemaker_client: 'boto3.Client' = None,
191
sagemaker_runtime_client: 'boto3.Client' = None, **kwargs): ...
192
193
def upload_data(self, path: str, bucket: str = None, key_prefix: str = None,
194
callback: callable = None, extra_args: dict = None) -> str: ...
195
196
def download_data(self, path: str, bucket: str, key_prefix: str,
197
extra_args: dict = None): ...
198
199
def create_training_job(self, **kwargs) -> dict: ...
200
201
def create_model(self, **kwargs) -> dict: ...
202
203
def create_endpoint_config(self, **kwargs) -> dict: ...
204
205
def create_endpoint(self, **kwargs) -> dict: ...
206
207
def wait_for_training_job(self, job_name: str, poll: int = 5): ...
208
209
def wait_for_endpoint(self, endpoint_name: str, poll: int = 30): ...
210
211
def default_bucket(self) -> str: ...
212
213
def delete_endpoint(self, endpoint_name: str): ...
214
215
def delete_endpoint_config(self, endpoint_config_name: str): ...
216
217
def delete_model(self, model_name: str): ...
218
219
class LocalSession(Session):
220
"""
221
Local session for local development and testing.
222
"""
223
def __init__(self, **kwargs): ...
224
225
def get_execution_role() -> str:
226
"""
227
Get the IAM execution role from the SageMaker notebook instance or environment.
228
229
Returns:
230
str: IAM role ARN
231
232
Raises:
233
ValueError: If role cannot be determined
234
"""
235
236
def container_def(image_uri: str, model_data_url: str = None, env: dict = None,
237
container_hostname: str = None, image_config: dict = None) -> dict:
238
"""
239
Create a container definition for use when creating a SageMaker model.
240
241
Parameters:
242
- image_uri (str): Docker image URI
243
- model_data_url (str, optional): S3 path to model artifacts
244
- env (dict, optional): Environment variables
245
- container_hostname (str, optional): Container hostname
246
- image_config (dict, optional): Image configuration
247
248
Returns:
249
dict: Container definition
250
"""
251
252
def pipeline_container_def(models: list, instance_type: str = None) -> list:
253
"""
254
Create container definitions for pipeline models.
255
256
Parameters:
257
- models (list): List of Model objects
258
- instance_type (str, optional): Instance type
259
260
Returns:
261
list: List of container definitions
262
"""
263
264
def production_variant(model_name: str, instance_type: str, initial_instance_count: int = 1,
265
variant_name: str = "AllTraffic", initial_weight: int = 1,
266
accelerator_type: str = None, serverless_inference_config: dict = None) -> dict:
267
"""
268
Create production variant configuration for endpoints.
269
270
Parameters:
271
- model_name (str): SageMaker model name
272
- instance_type (str): EC2 instance type
273
- initial_instance_count (int): Initial instance count
274
- variant_name (str): Variant name
275
- initial_weight (int): Traffic weight
276
- accelerator_type (str, optional): Accelerator type
277
- serverless_inference_config (dict, optional): Serverless config
278
279
Returns:
280
dict: Production variant configuration
281
"""
282
283
def get_model_package_args(content_types: list, response_types: list,
284
inference_instances: list = None, transform_instances: list = None) -> dict:
285
"""
286
Get model package arguments for registration.
287
288
Parameters:
289
- content_types (list): Supported content types
290
- response_types (list): Supported response types
291
- inference_instances (list, optional): Supported inference instances
292
- transform_instances (list, optional): Supported transform instances
293
294
Returns:
295
dict: Model package arguments
296
"""
297
```
298
299
## Usage Examples
300
301
### Basic Training and Deployment
302
303
```python
304
import sagemaker
305
from sagemaker import Estimator, Session, get_execution_role
306
307
# Set up session and role
308
session = Session()
309
role = get_execution_role()
310
311
# Create custom estimator
312
estimator = Estimator(
313
image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algorithm:latest",
314
role=role,
315
instance_count=1,
316
instance_type="ml.m5.large",
317
output_path="s3://my-bucket/model-artifacts"
318
)
319
320
# Train the model
321
estimator.fit({"training": "s3://my-bucket/training-data"})
322
323
# Deploy the model
324
predictor = estimator.deploy(
325
initial_instance_count=1,
326
instance_type="ml.m5.large"
327
)
328
329
# Make predictions (test_data is your own input payload, in a format the endpoint's serializer accepts)
330
result = predictor.predict(test_data)
331
332
# Clean up
333
predictor.delete_endpoint()
334
```
335
336
### Model Registration and Deployment
337
338
```python
339
from sagemaker import Model, ModelPackage
340
341
# Create model from artifacts (role is an IAM role ARN, e.g. obtained via get_execution_role() as in the previous example)
342
model = Model(
343
image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/inference:latest",
344
model_data="s3://my-bucket/model.tar.gz",
345
role=role
346
)
347
348
# Register model package
349
model_package = model.register(
350
content_types=["application/json"],
351
response_types=["application/json"],
352
inference_instances=["ml.m5.large"],
353
model_package_group_name="my-model-group"
354
)
355
356
# Deploy from model package
357
predictor = model_package.deploy(
358
initial_instance_count=1,
359
instance_type="ml.m5.large"
360
)
361
```