or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

amazon-algorithms.md automl.md core-training.md data-processing.md debugging-profiling.md experiments.md framework-training.md hyperparameter-tuning.md index.md model-monitoring.md model-serving.md remote-functions.md

core-training.md docs/

0

# Core Training and Model Management

1

2

Fundamental classes and functions for training models and managing deployments in SageMaker. These core components provide the foundation for all ML workflows including training, deployment, and inference.

3

4

## Capabilities

5

6

### Estimator Base Classes

7

8

Core estimator classes that provide the foundation for all training workflows, handling AWS service integration, resource management, and deployment.

9

10

```python { .api }

11

class Estimator:

12

"""

13

Base class for all SageMaker estimators.

14

15

Parameters:

16

- image_uri (str): Docker image URI for training

17

- role (str): IAM role ARN with SageMaker permissions

18

- instance_count (int): Number of training instances

19

- instance_type (str): EC2 instance type for training

20

- output_path (str, optional): S3 path for model artifacts

21

- sagemaker_session (Session, optional): SageMaker session

22

- hyperparameters (dict, optional): Algorithm hyperparameters

23

- environment (dict, optional): Environment variables

24

- max_run (int, optional): Maximum training time in seconds

25

- input_mode (str, optional): Training input mode ('File', 'Pipe', or 'FastFile')

26

- vpc_config (dict, optional): VPC configuration

27

- metric_definitions (list, optional): Custom metric definitions

28

"""

29

def __init__(self, image_uri: str, role: str, instance_count: int,

30

instance_type: str, output_path: str = None,

31

sagemaker_session: Session = None, **kwargs): ...

32

33

def fit(self, inputs, wait: bool = True, logs: bool = True,

34

job_name: str = None, experiment_config: dict = None): ...

35

36

def deploy(self, initial_instance_count: int, instance_type: str,

37

serializer: BaseSerializer = None, deserializer: BaseDeserializer = None,

38

accelerator_type: str = None, endpoint_name: str = None,

39

inference_component_name: str = None, **kwargs) -> Predictor: ...

40

41

def create_model(self, vpc_config_override: dict = None, **kwargs) -> Model: ...

42

43

def delete_endpoint(self, endpoint_name: str = None): ...

44

45

class Framework(Estimator):

46

"""

47

Base class for framework-specific estimators.

48

49

Parameters:

50

- entry_point (str): Path to training script

51

- source_dir (str, optional): Directory containing training code

52

- dependencies (list, optional): List of additional dependencies

53

- code_location (str, optional): S3 location for training code

54

"""

55

def __init__(self, entry_point: str, source_dir: str = None,

56

dependencies: list = None, **kwargs): ...

57

```

58

59

### Model Management

60

61

Model classes for deploying trained models, managing model artifacts, and creating inference endpoints.

62

63

```python { .api }

64

class Model:

65

"""

66

Model class for deploying trained models to SageMaker endpoints.

67

68

Parameters:

69

- image_uri (str): Docker image URI for inference

70

- model_data (str): S3 path to model artifacts

71

- role (str): IAM role ARN with SageMaker permissions

72

- predictor_cls (type, optional): Predictor class for deployment

73

- env (dict, optional): Environment variables

74

- name (str, optional): Model name

75

- vpc_config (dict, optional): VPC configuration

76

- sagemaker_session (Session, optional): SageMaker session

77

"""

78

def __init__(self, image_uri: str, model_data: str, role: str,

79

predictor_cls: type = None, env: dict = None, **kwargs): ...

80

81

def deploy(self, initial_instance_count: int, instance_type: str,

82

serializer: BaseSerializer = None, deserializer: BaseDeserializer = None,

83

accelerator_type: str = None, endpoint_name: str = None,

84

inference_component_name: str = None, **kwargs) -> Predictor: ...

85

86

def create(self, instance_type: str = None, accelerator_type: str = None): ...

87

88

def delete_model(self): ...

89

90

def register(self, content_types: list, response_types: list,

91

inference_instances: list = None, transform_instances: list = None,

92

model_package_group_name: str = None, **kwargs): ...

93

94

class ModelPackage:

95

"""

96

Model package class for versioned model management and deployment.

97

98

Parameters:

99

- role (str): IAM role ARN

100

- model_data (str, optional): S3 path to model artifacts

101

- image_uri (str, optional): Docker image URI

102

- model_package_arn (str, optional): Existing model package ARN

103

"""

104

def __init__(self, role: str, model_data: str = None, image_uri: str = None,

105

model_package_arn: str = None, **kwargs): ...

106

107

def deploy(self, initial_instance_count: int, instance_type: str, **kwargs) -> Predictor: ...

108

109

class PipelineModel:

110

"""

111

Pipeline model for chaining multiple models in sequence.

112

113

Parameters:

114

- name (str): Pipeline model name

115

- role (str): IAM role ARN

116

- models (list): List of Model objects to chain

117

"""

118

def __init__(self, name: str, role: str, models: list, **kwargs): ...

119

120

def deploy(self, initial_instance_count: int, instance_type: str, **kwargs) -> Predictor: ...

121

```

122

123

### Prediction and Inference

124

125

Predictor classes for making real-time and batch predictions against deployed models.

126

127

```python { .api }

128

class Predictor:

129

"""

130

Base predictor class for real-time inference.

131

132

Parameters:

133

- endpoint_name (str): SageMaker endpoint name

134

- sagemaker_session (Session, optional): SageMaker session

135

- serializer (BaseSerializer, optional): Input serializer

136

- deserializer (BaseDeserializer, optional): Output deserializer

137

"""

138

def __init__(self, endpoint_name: str, sagemaker_session: Session = None,

139

serializer: BaseSerializer = None, deserializer: BaseDeserializer = None): ...

140

141

def predict(self, data, initial_args: dict = None, target_model: str = None,

142

target_variant: str = None, inference_id: str = None): ...

143

144

def update_endpoint(self, initial_instance_count: int = None,

145

instance_type: str = None, **kwargs): ...

146

147

def delete_endpoint(self, delete_endpoint_config: bool = True): ...

148

149

def delete_model(self): ...

150

151

def enable_data_capture(self, sampling_percentage: int = 20,

152

capture_options: list = None): ...

153

154

def disable_data_capture(self): ...

155

156

class AsyncPredictor:

157

"""

158

Async predictor for asynchronous inference.

159

160

Parameters:

161

- predictor (Predictor): Base predictor instance

162

- name (str, optional): Async inference name

163

"""

164

def __init__(self, predictor: Predictor, name: str = None): ...

165

166

def predict_async(self, input_path: str, initial_args: dict = None,

167

inference_id: str = None) -> str: ...

168

169

def describe_async_inference_result(self, result_path: str) -> dict: ...

170

```

171

172

### Session Management

173

174

Session classes for managing AWS credentials, regions, and SageMaker service configurations.

175

176

```python { .api }

177

class Session:

178

"""

179

SageMaker session for managing service interactions.

180

181

Parameters:

182

- boto_session (boto3.Session, optional): Boto3 session

183

- sagemaker_client (boto3.Client, optional): SageMaker client

184

- sagemaker_runtime_client (boto3.Client, optional): SageMaker Runtime client

185

- default_bucket (str, optional): Default S3 bucket name

186

- s3_resource (boto3.Resource, optional): S3 resource

187

- settings (SessionSettings, optional): Session settings

188

"""

189

def __init__(self, boto_session: 'boto3.Session' = None,

190

sagemaker_client: 'boto3.Client' = None,

191

sagemaker_runtime_client: 'boto3.Client' = None, **kwargs): ...

192

193

def upload_data(self, path: str, bucket: str = None, key_prefix: str = None,

194

callback: callable = None, extra_args: dict = None) -> str: ...

195

196

def download_data(self, path: str, bucket: str, key_prefix: str,

197

extra_args: dict = None): ...

198

199

def create_training_job(self, **kwargs) -> dict: ...

200

201

def create_model(self, **kwargs) -> dict: ...

202

203

def create_endpoint_config(self, **kwargs) -> dict: ...

204

205

def create_endpoint(self, **kwargs) -> dict: ...

206

207

def wait_for_training_job(self, job_name: str, poll: int = 5): ...

208

209

def wait_for_endpoint(self, endpoint_name: str, poll: int = 30): ...

210

211

def default_bucket(self) -> str: ...

212

213

def delete_endpoint(self, endpoint_name: str): ...

214

215

def delete_endpoint_config(self, endpoint_config_name: str): ...

216

217

def delete_model(self, model_name: str): ...

218

219

class LocalSession(Session):

220

"""

221

Local session for local development and testing.

222

"""

223

def __init__(self, **kwargs): ...

224

225

def get_execution_role() -> str:

226

"""

227

Get the IAM execution role from the SageMaker notebook instance or environment.

228

229

Returns:

230

str: IAM role ARN

231

232

Raises:

233

ValueError: If role cannot be determined

234

"""

235

236

def container_def(image_uri: str, model_data_url: str = None, env: dict = None,

237

container_hostname: str = None, image_config: dict = None) -> dict:

238

"""

239

Create container definition for multi-model endpoints.

240

241

Parameters:

242

- image_uri (str): Docker image URI

243

- model_data_url (str, optional): S3 path to model artifacts

244

- env (dict, optional): Environment variables

245

- container_hostname (str, optional): Container hostname

246

- image_config (dict, optional): Image configuration

247

248

Returns:

249

dict: Container definition

250

"""

251

252

def pipeline_container_def(models: list, instance_type: str = None) -> list:

253

"""

254

Create container definitions for pipeline models.

255

256

Parameters:

257

- models (list): List of Model objects

258

- instance_type (str, optional): Instance type

259

260

Returns:

261

list: List of container definitions

262

"""

263

264

def production_variant(model_name: str, instance_type: str, initial_instance_count: int = 1,

265

variant_name: str = "AllTraffic", initial_weight: int = 1,

266

accelerator_type: str = None, serverless_inference_config: dict = None) -> dict:

267

"""

268

Create production variant configuration for endpoints.

269

270

Parameters:

271

- model_name (str): SageMaker model name

272

- instance_type (str): EC2 instance type

273

- initial_instance_count (int): Initial instance count

274

- variant_name (str): Variant name

275

- initial_weight (int): Traffic weight

276

- accelerator_type (str, optional): Accelerator type

277

- serverless_inference_config (dict, optional): Serverless config

278

279

Returns:

280

dict: Production variant configuration

281

"""

282

283

def get_model_package_args(content_types: list, response_types: list,

284

inference_instances: list = None, transform_instances: list = None) -> dict:

285

"""

286

Get model package arguments for registration.

287

288

Parameters:

289

- content_types (list): Supported content types

290

- response_types (list): Supported response types

291

- inference_instances (list, optional): Supported inference instances

292

- transform_instances (list, optional): Supported transform instances

293

294

Returns:

295

dict: Model package arguments

296

"""

297

```

298

299

## Usage Examples

300

301

### Basic Training and Deployment

302

303

```python

304

import sagemaker

305

from sagemaker import Estimator, Session, get_execution_role

306

307

# Set up session and role

308

session = Session()

309

role = get_execution_role()

310

311

# Create custom estimator

312

estimator = Estimator(

313

image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algorithm:latest",

314

role=role,

315

instance_count=1,

316

instance_type="ml.m5.large",

317

output_path="s3://my-bucket/model-artifacts"

318

)

319

320

# Train the model

321

estimator.fit({"training": "s3://my-bucket/training-data"})

322

323

# Deploy the model

324

predictor = estimator.deploy(

325

initial_instance_count=1,

326

instance_type="ml.m5.large"

327

)

328

329

# Make predictions

330

result = predictor.predict(test_data)  # test_data: placeholder for your own inference payload

331

332

# Clean up

333

predictor.delete_endpoint()

334

```

335

336

### Model Registration and Deployment

337

338

```python

339

from sagemaker import Model, ModelPackage

340

341

# Create model from artifacts

342

model = Model(

343

image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/inference:latest",

344

model_data="s3://my-bucket/model.tar.gz",

345

role=role  # IAM role ARN, e.g. from get_execution_role() as in the previous example

346

)

347

348

# Register model package

349

model_package = model.register(

350

content_types=["application/json"],

351

response_types=["application/json"],

352

inference_instances=["ml.m5.large"],

353

model_package_group_name="my-model-group"

354

)

355

356

# Deploy from model package

357

predictor = model_package.deploy(

358

initial_instance_count=1,

359

instance_type="ml.m5.large"

360

)

361

```