0
# Framework-Specific Training
1
2
Framework-specific estimators and models for popular machine learning frameworks including PyTorch, TensorFlow, Scikit-learn, XGBoost, Hugging Face, and MXNet. Each framework provides optimized training containers, deployment options, and processing capabilities.
3
4
## Capabilities
5
6
### PyTorch
7
8
PyTorch estimator for training and deploying PyTorch models with automatic environment setup and dependency management.
9
10
```python { .api }
11
class PyTorch(Framework):
12
"""
13
PyTorch estimator for training PyTorch models on SageMaker.
14
15
Parameters:
16
- entry_point (str): Path to Python training script
17
- framework_version (str): PyTorch version (e.g., "2.1.0")
18
- py_version (str): Python version ("py39", "py310", "py311")
19
- instance_type (str): EC2 instance type for training
20
- instance_count (int): Number of training instances
21
- role (str): IAM role ARN
22
- source_dir (str, optional): Directory containing training code
23
- dependencies (list, optional): Paths to directories containing additional libraries to export to the training container
24
- use_mpi (bool, optional): Enable MPI for distributed training
25
- use_spot_instances (bool, optional): Use spot instances
26
- distribution (dict, optional): Distributed training configuration, e.g. {"torch_distributed": {"enabled": True}}
27
"""
28
def __init__(self, entry_point: str, framework_version: str, py_version: str,
29
instance_type: str, instance_count: int, role: str, **kwargs): ...
30
31
class PyTorchModel(FrameworkModel):
32
"""
33
PyTorch model for deployment.
34
35
Parameters:
36
- model_data (str): S3 path to model artifacts
37
- role (str): IAM role ARN
38
- entry_point (str): Path to inference script
39
- framework_version (str): PyTorch version
40
- py_version (str): Python version
41
"""
42
def __init__(self, model_data: str, role: str, entry_point: str,
43
framework_version: str, py_version: str, **kwargs): ...
44
45
class PyTorchPredictor(Predictor):
46
"""
47
PyTorch predictor for real-time inference.
48
"""
49
def __init__(self, endpoint_name: str, **kwargs): ...
50
51
class PyTorchProcessor(ScriptProcessor):
52
"""
53
PyTorch processor for data processing jobs.
54
55
Parameters:
56
- framework_version (str): PyTorch version
57
- py_version (str): Python version
58
- instance_type (str): EC2 instance type
59
- instance_count (int): Number of processing instances
60
- role (str): IAM role ARN
61
"""
62
def __init__(self, framework_version: str, py_version: str,
63
instance_type: str, instance_count: int, role: str, **kwargs): ...
64
```
65
66
### TensorFlow
67
68
TensorFlow estimator for training and deploying TensorFlow models with support for TensorFlow Serving and custom inference.
69
70
```python { .api }
71
class TensorFlow(Framework):
72
"""
73
TensorFlow estimator for training TensorFlow models on SageMaker.
74
75
Parameters:
76
- entry_point (str): Path to Python training script
77
- framework_version (str): TensorFlow version (e.g., "2.13.0")
78
- py_version (str): Python version ("py39", "py310", "py311")
79
- instance_type (str): EC2 instance type for training
80
- instance_count (int): Number of training instances
81
- role (str): IAM role ARN
82
- distribution (dict, optional): Distributed training configuration
83
- model_dir (str, optional): S3 path for model checkpoints
84
"""
85
def __init__(self, entry_point: str, framework_version: str, py_version: str,
86
instance_type: str, instance_count: int, role: str, **kwargs): ...
87
88
class TensorFlowModel(FrameworkModel):
89
"""
90
TensorFlow model for deployment.
91
92
Parameters:
93
- model_data (str): S3 path to model artifacts
94
- role (str): IAM role ARN
95
- framework_version (str): TensorFlow version
96
- entry_point (str, optional): Path to inference script
97
"""
98
def __init__(self, model_data: str, role: str, framework_version: str,
99
entry_point: str = None, **kwargs): ...
100
101
class TensorFlowPredictor(Predictor):
102
"""
103
TensorFlow predictor for real-time inference.
104
"""
105
def __init__(self, endpoint_name: str, **kwargs): ...
106
107
class TensorFlowProcessor(ScriptProcessor):
108
"""
109
TensorFlow processor for data processing jobs.
110
111
Parameters:
112
- framework_version (str): TensorFlow version
113
- py_version (str): Python version
114
- instance_type (str): EC2 instance type
115
- instance_count (int): Number of processing instances
116
- role (str): IAM role ARN
117
"""
118
def __init__(self, framework_version: str, py_version: str,
119
instance_type: str, instance_count: int, role: str, **kwargs): ...
120
```
121
122
### Scikit-learn
123
124
Scikit-learn estimator for training and deploying scikit-learn models with pre-built containers and automatic dependency management.
125
126
```python { .api }
127
class SKLearn(Framework):
128
"""
129
Scikit-learn estimator for training scikit-learn models on SageMaker.
130
131
Parameters:
132
- entry_point (str): Path to Python training script
133
- framework_version (str): Scikit-learn version (e.g., "1.2-1")
134
- py_version (str): Python version ("py3" is the only supported value)
135
- instance_type (str): EC2 instance type for training
136
- role (str): IAM role ARN
137
- script_mode (bool, optional): Enable script mode
138
"""
139
def __init__(self, entry_point: str, framework_version: str, py_version: str,
140
instance_type: str, role: str, **kwargs): ...
141
142
class SKLearnModel(FrameworkModel):
143
"""
144
Scikit-learn model for deployment.
145
146
Parameters:
147
- model_data (str): S3 path to model artifacts
148
- role (str): IAM role ARN
149
- entry_point (str): Path to inference script
150
- framework_version (str): Scikit-learn version
151
"""
152
def __init__(self, model_data: str, role: str, entry_point: str,
153
framework_version: str, **kwargs): ...
154
155
class SKLearnPredictor(Predictor):
156
"""
157
Scikit-learn predictor for real-time inference.
158
"""
159
def __init__(self, endpoint_name: str, **kwargs): ...
160
161
class SKLearnProcessor(ScriptProcessor):
162
"""
163
Scikit-learn processor for data processing jobs.
164
165
Parameters:
166
- framework_version (str): Scikit-learn version
167
- instance_type (str): EC2 instance type
168
- instance_count (int): Number of processing instances
169
- role (str): IAM role ARN
170
"""
171
def __init__(self, framework_version: str, instance_type: str,
172
instance_count: int, role: str, **kwargs): ...
173
```
174
175
### XGBoost
176
177
XGBoost estimator for training and deploying XGBoost models with built-in algorithm support and custom training scripts.
178
179
```python { .api }
180
class XGBoost(Framework):
181
"""
182
XGBoost estimator for training XGBoost models on SageMaker.
183
184
Parameters:
185
- entry_point (str): Path to Python training script
186
- framework_version (str): XGBoost version (e.g., "1.7-1")
187
- py_version (str): Python version ("py3" is the only supported value)
188
- instance_type (str): EC2 instance type for training
189
- role (str): IAM role ARN
190
"""
191
def __init__(self, entry_point: str, framework_version: str, py_version: str,
192
instance_type: str, role: str, **kwargs): ...
193
194
class XGBoostModel(FrameworkModel):
195
"""
196
XGBoost model for deployment.
197
198
Parameters:
199
- model_data (str): S3 path to model artifacts
200
- role (str): IAM role ARN
201
- entry_point (str): Path to inference script
202
- framework_version (str): XGBoost version
203
"""
204
def __init__(self, model_data: str, role: str, entry_point: str,
205
framework_version: str, **kwargs): ...
206
207
class XGBoostPredictor(Predictor):
208
"""
209
XGBoost predictor for real-time inference.
210
"""
211
def __init__(self, endpoint_name: str, **kwargs): ...
212
213
class XGBoostProcessor(ScriptProcessor):
214
"""
215
XGBoost processor for data processing jobs.
216
217
Parameters:
218
- framework_version (str): XGBoost version
219
- instance_type (str): EC2 instance type
220
- instance_count (int): Number of processing instances
221
- role (str): IAM role ARN
222
"""
223
def __init__(self, framework_version: str, instance_type: str,
224
instance_count: int, role: str, **kwargs): ...
225
226
# Framework constant
227
XGBOOST_NAME = "xgboost"
228
```
229
230
### Hugging Face
231
232
Hugging Face estimator for training and deploying transformer models with automatic model hub integration and optimized inference.
233
234
```python { .api }
235
class HuggingFace(Framework):
236
"""
237
Hugging Face estimator for training transformer models on SageMaker.
238
239
Parameters:
240
- entry_point (str): Path to Python training script
241
- transformers_version (str): Transformers version (e.g., "4.36.0")
242
- pytorch_version (str): PyTorch version (e.g., "2.1.0")
243
- py_version (str): Python version ("py39", "py310", "py311")
244
- instance_type (str): EC2 instance type for training
245
- role (str): IAM role ARN
246
- hyperparameters (dict, optional): Training hyperparameters
247
- distribution (dict, optional): Distributed training configuration
248
"""
249
def __init__(self, entry_point: str, transformers_version: str, pytorch_version: str,
250
py_version: str, instance_type: str, role: str, **kwargs): ...
251
252
class HuggingFaceModel(FrameworkModel):
253
"""
254
Hugging Face model for deployment.
255
256
Parameters:
257
- model_data (str, optional): S3 path to model artifacts
258
- role (str): IAM role ARN
259
- transformers_version (str): Transformers version
260
- pytorch_version (str): PyTorch version
261
- env (dict, optional): Environment variables including HF_MODEL_ID
262
"""
263
def __init__(self, role: str, transformers_version: str, pytorch_version: str,
264
model_data: str = None, **kwargs): ...
265
266
class HuggingFacePredictor(Predictor):
267
"""
268
Hugging Face predictor for real-time inference.
269
"""
270
def __init__(self, endpoint_name: str, **kwargs): ...
271
272
class HuggingFaceProcessor(ScriptProcessor):
273
"""
274
Hugging Face processor for data processing jobs.
275
276
Parameters:
277
- transformers_version (str): Transformers version
278
- pytorch_version (str): PyTorch version
279
- instance_type (str): EC2 instance type
280
- instance_count (int): Number of processing instances
281
- role (str): IAM role ARN
282
"""
283
def __init__(self, transformers_version: str, pytorch_version: str,
284
instance_type: str, instance_count: int, role: str, **kwargs): ...
285
286
def get_huggingface_llm_image_uri(backend: str = "huggingface", region: str = None,
287
version: str = None) -> str:
288
"""
289
Get Hugging Face Large Language Model container image URI.
290
291
Parameters:
292
- backend (str): Inference backend ("huggingface", "lmi")
293
- region (str, optional): AWS region
294
- version (str, optional): Container version
295
296
Returns:
297
str: Container image URI
298
"""
299
```
300
301
### MXNet
302
303
MXNet estimator for training and deploying Apache MXNet models with Gluon support.
304
305
```python { .api }
306
class MXNet(Framework):
307
"""
308
MXNet estimator for training MXNet models on SageMaker.
309
310
Parameters:
311
- entry_point (str): Path to Python training script
312
- framework_version (str): MXNet version (e.g., "1.9.0")
313
- py_version (str): Python version ("py38", "py39")
314
- instance_type (str): EC2 instance type for training
315
- role (str): IAM role ARN
316
- distribution (dict, optional): Distributed training configuration
317
"""
318
def __init__(self, entry_point: str, framework_version: str, py_version: str,
319
instance_type: str, role: str, **kwargs): ...
320
321
class MXNetModel(FrameworkModel):
322
"""
323
MXNet model for deployment.
324
325
Parameters:
326
- model_data (str): S3 path to model artifacts
327
- role (str): IAM role ARN
328
- entry_point (str): Path to inference script
329
- framework_version (str): MXNet version
330
"""
331
def __init__(self, model_data: str, role: str, entry_point: str,
332
framework_version: str, **kwargs): ...
333
334
class MXNetPredictor(Predictor):
335
"""
336
MXNet predictor for real-time inference.
337
"""
338
def __init__(self, endpoint_name: str, **kwargs): ...
339
340
class MXNetProcessor(ScriptProcessor):
341
"""
342
MXNet processor for data processing jobs.
343
344
Parameters:
345
- framework_version (str): MXNet version
346
- instance_type (str): EC2 instance type
347
- instance_count (int): Number of processing instances
348
- role (str): IAM role ARN
349
"""
350
def __init__(self, framework_version: str, instance_type: str,
351
instance_count: int, role: str, **kwargs): ...
352
```
353
354
### Training Compiler Configuration
355
356
Configuration for SageMaker Training Compiler, which accelerates training on the frameworks that support it (Hugging Face, PyTorch, and TensorFlow).
357
358
```python { .api }
359
class TrainingCompilerConfig:
360
"""
361
Configuration for SageMaker Training Compiler optimization.
362
363
Parameters:
364
- enabled (bool): Enable training compiler
365
- debug (bool, optional): Enable debug mode
366
"""
367
def __init__(self, enabled: bool = True, debug: bool = False): ...
368
```
369
370
## Usage Examples
371
372
### PyTorch Training and Deployment
373
374
```python
375
from sagemaker.pytorch import PyTorch, PyTorchModel
376
377
# Create PyTorch estimator
378
pytorch_estimator = PyTorch(
379
entry_point="train.py",
380
source_dir="src",
381
role=role,
382
instance_type="ml.p3.2xlarge",
383
instance_count=1,
384
framework_version="2.1.0",
385
py_version="py310",
386
hyperparameters={
387
"epochs": 10,
388
"batch_size": 32
389
}
390
)
391
392
# Train the model
393
pytorch_estimator.fit({"training": training_data_path})
394
395
# Deploy using the estimator
396
predictor = pytorch_estimator.deploy(
397
initial_instance_count=1,
398
instance_type="ml.m5.large"
399
)
400
401
# Or create a model from the trained artifacts and deploy it separately
402
pytorch_model = PyTorchModel(
403
model_data=pytorch_estimator.model_data,
404
role=role,
405
entry_point="inference.py",
406
framework_version="2.1.0",
407
py_version="py310"
408
)
409
410
predictor = pytorch_model.deploy(
411
initial_instance_count=1,
412
instance_type="ml.m5.large"
413
)
414
```
415
416
### Hugging Face Training
417
418
```python
419
from sagemaker.huggingface import HuggingFace
420
421
# Create Hugging Face estimator
422
huggingface_estimator = HuggingFace(
423
entry_point="train.py",
424
role=role,
425
instance_type="ml.p3.2xlarge",
426
instance_count=1,
427
transformers_version="4.36.0",
428
pytorch_version="2.1.0",
429
py_version="py310",
430
hyperparameters={
431
"model_name_or_path": "bert-base-uncased",
432
"num_train_epochs": 3,
433
"per_device_train_batch_size": 16
434
}
435
)
436
437
# Train the model
438
huggingface_estimator.fit({"train": train_data, "test": test_data})
439
440
# Deploy for inference
441
predictor = huggingface_estimator.deploy(
442
initial_instance_count=1,
443
instance_type="ml.g4dn.xlarge"
444
)
445
```
446
447
### Distributed Training Example
448
449
```python
450
from sagemaker.pytorch import PyTorch
451
452
# Configure distributed training
453
distribution = {
454
"torch_distributed": {
455
"enabled": True
456
}
457
}
458
459
pytorch_estimator = PyTorch(
460
entry_point="train_distributed.py",
461
role=role,
462
instance_type="ml.p3.8xlarge",
463
instance_count=2, # Multiple instances
464
framework_version="2.1.0",
465
py_version="py310",
466
distribution=distribution,
467
hyperparameters={
468
"backend": "nccl"
469
}
470
)
471
472
pytorch_estimator.fit(training_data_path)
473
```