or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

amazon-algorithms.md, automl.md, core-training.md, data-processing.md, debugging-profiling.md, experiments.md, framework-training.md, hyperparameter-tuning.md, index.md, model-monitoring.md, model-serving.md, remote-functions.md

docs/framework-training.md

0

# Framework-Specific Training

1

2

Framework-specific estimators and models for popular machine learning frameworks including PyTorch, TensorFlow, Scikit-learn, XGBoost, Hugging Face, and MXNet. Each framework provides optimized training containers, deployment options, and processing capabilities.

3

4

## Capabilities

5

6

### PyTorch

7

8

PyTorch estimator for training and deploying PyTorch models with automatic environment setup and dependency management.

9

10

```python { .api }

11

class PyTorch(Framework):

12

"""

13

PyTorch estimator for training PyTorch models on SageMaker.

14

15

Parameters:

16

- entry_point (str): Path to Python training script

17

- framework_version (str): PyTorch version (e.g., "2.1.0")

18

- py_version (str): Python version ("py39", "py310", "py311")

19

- instance_type (str): EC2 instance type for training

20

- instance_count (int): Number of training instances

21

- role (str): IAM role ARN

22

- source_dir (str, optional): Directory containing training code

23

- dependencies (list, optional): Paths to directories containing additional libraries to export to the training container

24

- use_mpi (bool, optional): Enable MPI for distributed training

25

- use_spot_instances (bool, optional): Use spot instances

26

- distribution (dict, optional): Distributed training configuration

27

"""

28

def __init__(self, entry_point: str, framework_version: str, py_version: str,

29

instance_type: str, instance_count: int, role: str, **kwargs): ...

30

31

class PyTorchModel(FrameworkModel):

32

"""

33

PyTorch model for deployment.

34

35

Parameters:

36

- model_data (str): S3 path to model artifacts

37

- role (str): IAM role ARN

38

- entry_point (str): Path to inference script

39

- framework_version (str): PyTorch version

40

- py_version (str): Python version

41

"""

42

def __init__(self, model_data: str, role: str, entry_point: str,

43

framework_version: str, py_version: str, **kwargs): ...

44

45

class PyTorchPredictor(Predictor):

46

"""

47

PyTorch predictor for real-time inference.

48

"""

49

def __init__(self, endpoint_name: str, **kwargs): ...

50

51

class PyTorchProcessor(ScriptProcessor):

52

"""

53

PyTorch processor for data processing jobs.

54

55

Parameters:

56

- framework_version (str): PyTorch version

57

- py_version (str): Python version

58

- instance_type (str): EC2 instance type

59

- instance_count (int): Number of processing instances

60

- role (str): IAM role ARN

61

"""

62

def __init__(self, framework_version: str, py_version: str,

63

instance_type: str, instance_count: int, role: str, **kwargs): ...

64

```

65

66

### TensorFlow

67

68

TensorFlow estimator for training and deploying TensorFlow models with support for TensorFlow Serving and custom inference.

69

70

```python { .api }

71

class TensorFlow(Framework):

72

"""

73

TensorFlow estimator for training TensorFlow models on SageMaker.

74

75

Parameters:

76

- entry_point (str): Path to Python training script

77

- framework_version (str): TensorFlow version (e.g., "2.13.0")

78

- py_version (str): Python version ("py39", "py310", "py311")

79

- instance_type (str): EC2 instance type for training

80

- instance_count (int): Number of training instances

81

- role (str): IAM role ARN

82

- distribution (dict, optional): Distributed training configuration

83

- model_dir (str, optional): S3 path for model checkpoints

84

"""

85

def __init__(self, entry_point: str, framework_version: str, py_version: str,

86

instance_type: str, instance_count: int, role: str, **kwargs): ...

87

88

class TensorFlowModel(FrameworkModel):

89

"""

90

TensorFlow model for deployment.

91

92

Parameters:

93

- model_data (str): S3 path to model artifacts

94

- role (str): IAM role ARN

95

- framework_version (str): TensorFlow version

96

- entry_point (str, optional): Path to inference script

97

"""

98

def __init__(self, model_data: str, role: str, framework_version: str,

99

entry_point: str = None, **kwargs): ...

100

101

class TensorFlowPredictor(Predictor):

102

"""

103

TensorFlow predictor for real-time inference.

104

"""

105

def __init__(self, endpoint_name: str, **kwargs): ...

106

107

class TensorFlowProcessor(ScriptProcessor):

108

"""

109

TensorFlow processor for data processing jobs.

110

111

Parameters:

112

- framework_version (str): TensorFlow version

113

- py_version (str): Python version

114

- instance_type (str): EC2 instance type

115

- instance_count (int): Number of processing instances

116

- role (str): IAM role ARN

117

"""

118

def __init__(self, framework_version: str, py_version: str,

119

instance_type: str, instance_count: int, role: str, **kwargs): ...

120

```

121

122

### Scikit-learn

123

124

Scikit-learn estimator for training and deploying scikit-learn models with pre-built containers and automatic dependency management.

125

126

```python { .api }

127

class SKLearn(Framework):

128

"""

129

Scikit-learn estimator for training scikit-learn models on SageMaker.

130

131

Parameters:

132

- entry_point (str): Path to Python training script

133

- framework_version (str): Scikit-learn version (e.g., "1.2-1")

134

- py_version (str): Python version ("py39", "py310")

135

- instance_type (str): EC2 instance type for training

136

- role (str): IAM role ARN

137

- script_mode (bool, optional): Enable script mode

138

"""

139

def __init__(self, entry_point: str, framework_version: str, py_version: str,

140

instance_type: str, role: str, **kwargs): ...

141

142

class SKLearnModel(FrameworkModel):

143

"""

144

Scikit-learn model for deployment.

145

146

Parameters:

147

- model_data (str): S3 path to model artifacts

148

- role (str): IAM role ARN

149

- entry_point (str): Path to inference script

150

- framework_version (str): Scikit-learn version

151

"""

152

def __init__(self, model_data: str, role: str, entry_point: str,

153

framework_version: str, **kwargs): ...

154

155

class SKLearnPredictor(Predictor):

156

"""

157

Scikit-learn predictor for real-time inference.

158

"""

159

def __init__(self, endpoint_name: str, **kwargs): ...

160

161

class SKLearnProcessor(ScriptProcessor):

162

"""

163

Scikit-learn processor for data processing jobs.

164

165

Parameters:

166

- framework_version (str): Scikit-learn version

167

- instance_type (str): EC2 instance type

168

- instance_count (int): Number of processing instances

169

- role (str): IAM role ARN

170

"""

171

def __init__(self, framework_version: str, instance_type: str,

172

instance_count: int, role: str, **kwargs): ...

173

```

174

175

### XGBoost

176

177

XGBoost estimator for training and deploying XGBoost models with built-in algorithm support and custom training scripts.

178

179

```python { .api }

180

class XGBoost(Framework):

181

"""

182

XGBoost estimator for training XGBoost models on SageMaker.

183

184

Parameters:

185

- entry_point (str): Path to Python training script

186

- framework_version (str): XGBoost version (e.g., "1.7-1")

187

- py_version (str): Python version ("py39", "py310")

188

- instance_type (str): EC2 instance type for training

189

- role (str): IAM role ARN

190

"""

191

def __init__(self, entry_point: str, framework_version: str, py_version: str,

192

instance_type: str, role: str, **kwargs): ...

193

194

class XGBoostModel(FrameworkModel):

195

"""

196

XGBoost model for deployment.

197

198

Parameters:

199

- model_data (str): S3 path to model artifacts

200

- role (str): IAM role ARN

201

- entry_point (str): Path to inference script

202

- framework_version (str): XGBoost version

203

"""

204

def __init__(self, model_data: str, role: str, entry_point: str,

205

framework_version: str, **kwargs): ...

206

207

class XGBoostPredictor(Predictor):

208

"""

209

XGBoost predictor for real-time inference.

210

"""

211

def __init__(self, endpoint_name: str, **kwargs): ...

212

213

class XGBoostProcessor(ScriptProcessor):

214

"""

215

XGBoost processor for data processing jobs.

216

217

Parameters:

218

- framework_version (str): XGBoost version

219

- instance_type (str): EC2 instance type

220

- instance_count (int): Number of processing instances

221

- role (str): IAM role ARN

222

"""

223

def __init__(self, framework_version: str, instance_type: str,

224

instance_count: int, role: str, **kwargs): ...

225

226

# Framework constant

227

XGBOOST_NAME = "xgboost"

228

```

229

230

### Hugging Face

231

232

Hugging Face estimator for training and deploying transformer models with automatic model hub integration and optimized inference.

233

234

```python { .api }

235

class HuggingFace(Framework):

236

"""

237

Hugging Face estimator for training transformer models on SageMaker.

238

239

Parameters:

240

- entry_point (str): Path to Python training script

241

- transformers_version (str): Transformers version (e.g., "4.36.0")

242

- pytorch_version (str): PyTorch version (e.g., "2.1.0")

243

- py_version (str): Python version ("py39", "py310", "py311")

244

- instance_type (str): EC2 instance type for training

245

- role (str): IAM role ARN

246

- hyperparameters (dict, optional): Training hyperparameters

247

- distribution (dict, optional): Distributed training configuration

248

"""

249

def __init__(self, entry_point: str, transformers_version: str, pytorch_version: str,

250

py_version: str, instance_type: str, role: str, **kwargs): ...

251

252

class HuggingFaceModel(FrameworkModel):

253

"""

254

Hugging Face model for deployment.

255

256

Parameters:

257

- model_data (str, optional): S3 path to model artifacts

258

- role (str): IAM role ARN

259

- transformers_version (str): Transformers version

260

- pytorch_version (str): PyTorch version

261

- env (dict, optional): Environment variables including HF_MODEL_ID

262

"""

263

def __init__(self, role: str, transformers_version: str, pytorch_version: str,

264

model_data: str = None, **kwargs): ...

265

266

class HuggingFacePredictor(Predictor):

267

"""

268

Hugging Face predictor for real-time inference.

269

"""

270

def __init__(self, endpoint_name: str, **kwargs): ...

271

272

class HuggingFaceProcessor(ScriptProcessor):

273

"""

274

Hugging Face processor for data processing jobs.

275

276

Parameters:

277

- transformers_version (str): Transformers version

278

- pytorch_version (str): PyTorch version

279

- instance_type (str): EC2 instance type

280

- instance_count (int): Number of processing instances

281

- role (str): IAM role ARN

282

"""

283

def __init__(self, transformers_version: str, pytorch_version: str,

284

instance_type: str, instance_count: int, role: str, **kwargs): ...

285

286

def get_huggingface_llm_image_uri(backend: str = "huggingface", region: str = None,

287

version: str = None) -> str:

288

"""

289

Get Hugging Face Large Language Model container image URI.

290

291

Parameters:

292

- backend (str): Inference backend ("huggingface", "lmi")

293

- region (str, optional): AWS region

294

- version (str, optional): Container version

295

296

Returns:

297

str: Container image URI

298

"""

299

```

300

301

### MXNet

302

303

MXNet estimator for training and deploying Apache MXNet models with Gluon support.

304

305

```python { .api }

306

class MXNet(Framework):

307

"""

308

MXNet estimator for training MXNet models on SageMaker.

309

310

Parameters:

311

- entry_point (str): Path to Python training script

312

- framework_version (str): MXNet version (e.g., "1.9.0")

313

- py_version (str): Python version ("py38", "py39")

314

- instance_type (str): EC2 instance type for training

315

- role (str): IAM role ARN

316

- distribution (dict, optional): Distributed training configuration

317

"""

318

def __init__(self, entry_point: str, framework_version: str, py_version: str,

319

instance_type: str, role: str, **kwargs): ...

320

321

class MXNetModel(FrameworkModel):

322

"""

323

MXNet model for deployment.

324

325

Parameters:

326

- model_data (str): S3 path to model artifacts

327

- role (str): IAM role ARN

328

- entry_point (str): Path to inference script

329

- framework_version (str): MXNet version

330

"""

331

def __init__(self, model_data: str, role: str, entry_point: str,

332

framework_version: str, **kwargs): ...

333

334

class MXNetPredictor(Predictor):

335

"""

336

MXNet predictor for real-time inference.

337

"""

338

def __init__(self, endpoint_name: str, **kwargs): ...

339

340

class MXNetProcessor(ScriptProcessor):

341

"""

342

MXNet processor for data processing jobs.

343

344

Parameters:

345

- framework_version (str): MXNet version

346

- instance_type (str): EC2 instance type

347

- instance_count (int): Number of processing instances

348

- role (str): IAM role ARN

349

"""

350

def __init__(self, framework_version: str, instance_type: str,

351

instance_count: int, role: str, **kwargs): ...

352

```

353

354

### Training Compiler Configuration

355

356

Optimization configuration for accelerating training performance across frameworks.

357

358

```python { .api }

359

class TrainingCompilerConfig:

360

"""

361

Configuration for SageMaker Training Compiler optimization.

362

363

Parameters:

364

- enabled (bool): Enable training compiler

365

- debug (bool, optional): Enable debug mode

366

"""

367

def __init__(self, enabled: bool = True, debug: bool = False): ...

368

```

369

370

## Usage Examples

371

372

### PyTorch Training and Deployment

373

374

```python

375

from sagemaker.pytorch import PyTorch, PyTorchModel

376

377

# Create PyTorch estimator

378

pytorch_estimator = PyTorch(

379

entry_point="train.py",

380

source_dir="src",

381

role=role,

382

instance_type="ml.p3.2xlarge",

383

instance_count=1,

384

framework_version="2.1.0",

385

py_version="py310",

386

hyperparameters={

387

"epochs": 10,

388

"batch_size": 32

389

}

390

)

391

392

# Train the model

393

pytorch_estimator.fit({"training": training_data_path})

394

395

# Deploy using the estimator

396

predictor = pytorch_estimator.deploy(

397

initial_instance_count=1,

398

instance_type="ml.m5.large"

399

)

400

401

# Or create model and deploy separately

402

pytorch_model = PyTorchModel(

403

model_data=pytorch_estimator.model_data,

404

role=role,

405

entry_point="inference.py",

406

framework_version="2.1.0",

407

py_version="py310"

408

)

409

410

predictor = pytorch_model.deploy(

411

initial_instance_count=1,

412

instance_type="ml.m5.large"

413

)

414

```

415

416

### Hugging Face Training

417

418

```python

419

from sagemaker.huggingface import HuggingFace

420

421

# Create Hugging Face estimator

422

huggingface_estimator = HuggingFace(

423

entry_point="train.py",

424

role=role,

425

instance_type="ml.p3.2xlarge",

426

instance_count=1,

427

transformers_version="4.36.0",

428

pytorch_version="2.1.0",

429

py_version="py310",

430

hyperparameters={

431

"model_name_or_path": "bert-base-uncased",

432

"num_train_epochs": 3,

433

"per_device_train_batch_size": 16

434

}

435

)

436

437

# Train the model

438

huggingface_estimator.fit({"train": train_data, "test": test_data})

439

440

# Deploy for inference

441

predictor = huggingface_estimator.deploy(

442

initial_instance_count=1,

443

instance_type="ml.g4dn.xlarge"

444

)

445

```

446

447

### Distributed Training Example

448

449

```python

450

from sagemaker.pytorch import PyTorch

451

452

# Configure distributed training

453

distribution = {

454

"torch_distributed": {

455

"enabled": True

456

}

457

}

458

459

pytorch_estimator = PyTorch(

460

entry_point="train_distributed.py",

461

role=role,

462

instance_type="ml.p3.8xlarge",

463

instance_count=2, # Multiple instances

464

framework_version="2.1.0",

465

py_version="py310",

466

distribution=distribution,

467

hyperparameters={

468

"backend": "nccl"

469

}

470

)

471

472

pytorch_estimator.fit(training_data_path)

473

```