
# Model Management

Complete model serialization, checkpointing, and deployment utilities for production and inference, covering the full model lifecycle.

## Capabilities

### Model Saving and Loading

Save and load complete models with all weights, architecture, and training configuration.

```python { .api }
def save(obj, export_dir, signatures=None, options=None):
    """
    Exports a tf.Module (and subclasses) obj to SavedModel format.

    Parameters:
    - obj: A trackable object (e.g. tf.Module or tf.keras.Model) to export
    - export_dir: A directory in which to write the SavedModel
    - signatures: Optional, either a tf.function with an input signature specified or a dictionary mapping signature keys to tf.function instances with input signatures
    - options: Optional, tf.saved_model.SaveOptions object that specifies options for saving
    """

def load(export_dir, tags=None, options=None):
    """
    Loads a SavedModel from export_dir.

    Parameters:
    - export_dir: The SavedModel directory to load from
    - tags: A tag or sequence of tags identifying the MetaGraph to load
    - options: Optional, tf.saved_model.LoadOptions object that specifies options for loading

    Returns:
    A trackable object with a signatures attribute mapping signature keys to functions
    """

def contains_saved_model(export_dir):
    """
    Checks whether the provided export directory could contain a SavedModel.

    Parameters:
    - export_dir: Absolute or relative path to a directory containing the SavedModel

    Returns:
    True if the export directory contains SavedModel files, False otherwise
    """
```
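
A minimal sketch of the save/load round trip, calling the loaded model through its `serving_default` signature. The `Doubler` module and the `/tmp/demo_model` directory are illustrative, not part of the API above:

```python
import tensorflow as tf

# A small trackable module with an explicitly signed function to export.
class Doubler(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * 2.0

module = Doubler()
tf.saved_model.save(module, '/tmp/demo_model',
                    signatures={'serving_default': module.__call__})

# Loading returns a trackable object; concrete functions are exposed
# through the signatures attribute and are called with keyword arguments.
restored = tf.saved_model.load('/tmp/demo_model')
fn = restored.signatures['serving_default']
print(fn(x=tf.constant([1.0, 2.0])))  # dict of named output tensors
```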

### Checkpointing

Save and restore model weights and training state for resuming training.

```python { .api }
class Checkpoint:
    """
    Groups trackable objects, saving and restoring them.

    Methods:
    - save(file_prefix): Saves a training checkpoint and provides basic checkpoint management
    - restore(save_path): Restores a training checkpoint
    - read(save_path): Reads a checkpoint and returns a load status object
    """

    def __init__(self, **kwargs):
        """
        Groups trackable objects, saving and restoring them.

        Parameters:
        - **kwargs: Keyword arguments are set as attributes of this object, and are saved with the checkpoint
        """

    def save(self, file_prefix, session=None):
        """
        Saves a training checkpoint and provides basic checkpoint management.

        Parameters:
        - file_prefix: A prefix to use for the checkpoint filenames
        - session: The session to evaluate variables in. Ignored when executing eagerly

        Returns:
        The full path to the checkpoint
        """

    def restore(self, save_path):
        """
        Restores a training checkpoint.

        Parameters:
        - save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint

        Returns:
        A load status object, which can be used to make assertions about the status of a checkpoint restoration
        """

    def read(self, save_path):
        """
        Reads a checkpoint, without the save-counter bookkeeping that save adds.

        Parameters:
        - save_path: The path to the checkpoint, as returned by write or save

        Returns:
        A load status object, which can be used to make assertions about the status of a checkpoint restoration
        """

class CheckpointManager:
    """
    Manages multiple checkpoints, keeping some and deleting unneeded ones.

    Methods:
    - save(checkpoint_number): Creates a new checkpoint and deletes old ones
    """

    def __init__(self, checkpoint, directory, max_to_keep=5, keep_checkpoint_every_n_hours=None,
                 checkpoint_name="ckpt", step_counter=None, checkpoint_interval=None,
                 init_fn=None):
        """
        Manages multiple checkpoints, keeping some and deleting unneeded ones.

        Parameters:
        - checkpoint: The tf.train.Checkpoint instance to save and manage checkpoints for
        - directory: The path to a directory in which to write checkpoints
        - max_to_keep: An integer, the number of checkpoints to keep
        - keep_checkpoint_every_n_hours: Upon removal, keep one checkpoint every N hours
        - checkpoint_name: Custom name for the checkpoint file
        - step_counter: A tf.Variable instance for checking the current step counter value
        - checkpoint_interval: An integer, the minimum step interval between checkpoints
        - init_fn: Callable. Function executed for customized initialization when no checkpoint is found
        """

    def save(self, checkpoint_number=None, check_interval=True):
        """
        Creates a new checkpoint and manages deletion of old checkpoints.

        Parameters:
        - checkpoint_number: An optional integer, or an integer-dtype Variable or Tensor, used to number the checkpoint
        - check_interval: An optional boolean, only effective when checkpoint_interval is set; if True, saving is skipped unless the step counter has advanced by at least checkpoint_interval since the last checkpoint

        Returns:
        The path to the new checkpoint. It is also recorded in the checkpoints and latest_checkpoint properties
        """
```
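
As a sketch of how the load-status object returned by `restore` can be used, the snippet below saves a tiny checkpoint, mutates a variable, then restores and asserts that every tracked object was matched. The variable names and `/tmp/demo_ckpt` path are illustrative:

```python
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
bias = tf.Variable([0.5, -0.5])
ckpt = tf.train.Checkpoint(step=step, bias=bias)

# Save, mutate, then restore and verify the restore matched everything
# the checkpoint tracks.
path = ckpt.save('/tmp/demo_ckpt/ckpt')
bias.assign([0.0, 0.0])
status = ckpt.restore(path)
status.assert_existing_objects_matched()  # raises if something was not restored
print(bias.numpy())  # back to [0.5, -0.5]
```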

### Checkpoint Utilities

Utility functions for working with checkpoints.

```python { .api }
def list_variables(checkpoint_dir):
    """
    Returns a list of all variables in the checkpoint.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint

    Returns:
    List of tuples (name, shape) for all variables in the checkpoint
    """

def load_checkpoint(checkpoint_dir):
    """
    Returns a CheckpointReader for the checkpoint found in checkpoint_dir.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint

    Returns:
    CheckpointReader instance
    """

def load_variable(checkpoint_dir, name):
    """
    Returns the tensor value of the given variable in the checkpoint.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint
    - name: Name of the variable to return

    Returns:
    A numpy ndarray with a copy of the value of this variable
    """

def latest_checkpoint(checkpoint_dir, latest_filename=None):
    """
    Finds the filename of the latest saved checkpoint file.

    Parameters:
    - checkpoint_dir: Directory where the variables were saved
    - latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames

    Returns:
    The full path to the latest checkpoint, or None if no checkpoint was found
    """
```
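
A small sketch tying these together: find the latest checkpoint, open a `CheckpointReader`, and read one tensor by its checkpoint key. The `./training_checkpoints` directory is illustrative, and key names depend on the saved object graph:

```python
import tensorflow as tf

path = tf.train.latest_checkpoint('./training_checkpoints')
if path:
    reader = tf.train.load_checkpoint(path)
    # Map of checkpoint key -> variable shape.
    shape_map = reader.get_variable_to_shape_map()
    for key, shape in sorted(shape_map.items()):
        print(key, shape)
    # get_tensor returns the stored value as a numpy array.
    first_key = next(iter(sorted(shape_map)))
    print(reader.get_tensor(first_key).shape)
```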

### SavedModel Utilities

Additional utilities for working with the SavedModel format.

```python { .api }
class SaveOptions:
    """
    Options for saving to SavedModel.

    Parameters:
    - namespace_whitelist: List of strings containing op namespaces to whitelist when saving a model
    - save_debug_info: Boolean indicating whether debug information is saved
    - function_aliases: Optional dictionary mapping string alias names to tf.functions
    - experimental_io_device: string. Applies in a distributed setting; the TensorFlow device to use to access the filesystem
    - experimental_variable_policy: The policy to apply to variables when saving
    """

class LoadOptions:
    """
    Options for loading a SavedModel.

    Parameters:
    - allow_partial_checkpoint: Boolean. Defaults to False. When enabled, allows the SavedModel checkpoint to be missing variables
    - experimental_io_device: string. Loads the SavedModel and variables on the specified device
    - experimental_skip_checkpoint: boolean. If True, the checkpoint will not be loaded, and the SavedModel will be loaded with randomly initialized variable values
    """

class Asset:
    """
    Represents a file asset to copy into the SavedModel.

    Parameters:
    - path: A path, or a 0-D tf.string Tensor with a path to the asset
    """
```
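
A minimal sketch of bundling a file into a SavedModel via `Asset`, saved with `SaveOptions`. The `WithVocab` class and the `/tmp/vocab.txt` and `/tmp/asset_model` paths are illustrative, and the vocabulary file must exist before `save` is called:

```python
import tensorflow as tf

# A module that tracks a vocabulary file so it is copied into the
# SavedModel's assets directory at save time.
class WithVocab(tf.Module):
    def __init__(self, vocab_path):
        super().__init__()
        self.vocab = tf.saved_model.Asset(vocab_path)

module = WithVocab('/tmp/vocab.txt')  # file must already exist
options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.saved_model.save(module, '/tmp/asset_model', options=options)

# After loading, the asset path points inside the SavedModel directory.
restored = tf.saved_model.load('/tmp/asset_model')
print(restored.vocab.asset_path.numpy())
```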

## Usage Examples

```python
import tensorflow as tf
import os

# Create a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Save entire model to SavedModel format
tf.saved_model.save(model, 'my_saved_model')

# Load the saved model
loaded_model = tf.saved_model.load('my_saved_model')

# For Keras models, use the Keras save/load APIs for full functionality
model.save('my_keras_model.h5')
loaded_keras_model = tf.keras.models.load_model('my_keras_model.h5')

# Checkpoint example
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Create checkpoint object
checkpoint = tf.train.Checkpoint(optimizer=tf.keras.optimizers.Adam(),
                                 model=model)

# Save checkpoint
checkpoint.save(file_prefix=checkpoint_prefix)

# Restore from checkpoint
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# Using CheckpointManager for automatic cleanup
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3
)

# Save with automatic cleanup
save_path = manager.save()
print(f"Saved checkpoint: {save_path}")

# Training loop with checkpointing
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(checkpoint, './checkpoints', max_to_keep=3)

# Restore if a checkpoint exists (restore(None) is a no-op)
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f"Restored from {manager.latest_checkpoint}")
else:
    print("Initializing from scratch.")

# Training step function
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y, predictions))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

# Training loop
for epoch in range(10):
    # Training code here...
    # x_batch, y_batch = get_batch()
    # loss = train_step(x_batch, y_batch)

    # Save checkpoint every few epochs
    if epoch % 2 == 0:
        save_path = manager.save()
        print(f"Saved checkpoint for epoch {epoch}: {save_path}")

# Inspect checkpoint contents
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint_path:
    variables = tf.train.list_variables(checkpoint_path)
    for name, shape in variables:
        print(f"Variable: {name}, Shape: {shape}")

    # Load a specific variable by its checkpoint key (key names depend on
    # the object graph; use list_variables to discover them)
    specific_var = tf.train.load_variable(checkpoint_path, 'model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE')
    print(f"Loaded variable shape: {specific_var.shape}")

# Check if directory contains SavedModel
if tf.saved_model.contains_saved_model('my_saved_model'):
    print("Directory contains a valid SavedModel")

# Advanced SavedModel with custom signatures
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def inference_func(x):
    return model(x)

# Save with custom signature
tf.saved_model.save(
    model,
    'model_with_signature',
    signatures={'serving_default': inference_func}
)
```