# Model Management

Complete model serialization, checkpointing, and deployment utilities for production and inference. These operations provide comprehensive model lifecycle management capabilities.

## Capabilities

### Model Saving and Loading

Save and load complete models with all weights, architecture, and training configuration.

```python { .api }
def save(obj, export_dir, signatures=None, options=None):
    """
    Exports a tf.Module (and subclasses) obj to SavedModel format.

    Parameters:
    - obj: A trackable object (e.g. tf.Module or tf.keras.Model) to export
    - export_dir: A directory in which to write the SavedModel
    - signatures: Optional, either a tf.function with an input signature specified or a dictionary
    - options: Optional, tf.saved_model.SaveOptions object that specifies options for saving
    """

def load(export_dir, tags=None, options=None):
    """
    Load a SavedModel from export_dir.

    Parameters:
    - export_dir: The SavedModel directory to load from
    - tags: A tag or sequence of tags identifying the MetaGraph to load
    - options: Optional, tf.saved_model.LoadOptions object that specifies options for loading

    Returns:
    A trackable object with a save method
    """

def contains_saved_model(export_dir):
    """
    Checks whether the provided export directory could contain a SavedModel.

    Parameters:
    - export_dir: Absolute or relative path to a directory containing the SavedModel

    Returns:
    True if the export directory contains SavedModel files, False otherwise
    """
```

### Checkpointing

Save and restore model weights and training state for resuming training.

```python { .api }
class Checkpoint:
    """
    Groups trackable objects, saving and restoring them.

    Methods:
    - save(file_prefix): Saves a training checkpoint and provides a context manager
    - restore(save_path): Restore a training checkpoint
    - read(save_path): Returns CheckpointReader for checkpoint inspection
    """

    def __init__(self, **kwargs):
        """
        Groups trackable objects, saving and restoring them.

        Parameters:
        - **kwargs: Keyword arguments are set as attributes of this object, and are saved with the checkpoint
        """

    def save(self, file_prefix, session=None):
        """
        Saves a training checkpoint and provides a context manager.

        Parameters:
        - file_prefix: A prefix to use for the checkpoint filenames
        - session: The session to evaluate variables in. Ignored when executing eagerly

        Returns:
        The full path to the checkpoint
        """

    def restore(self, save_path):
        """
        Restore a training checkpoint.

        Parameters:
        - save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint

        Returns:
        A load status object, which can be used to make assertions about the status of a checkpoint restoration
        """

    def read(self, save_path):
        """
        Returns a CheckpointReader for the checkpoint.

        Parameters:
        - save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint

        Returns:
        A CheckpointReader object
        """

class CheckpointManager:
    """
    Deletes old checkpoints.

    Methods:
    - save(checkpoint_number): Creates a new checkpoint
    """

    def __init__(self, checkpoint, directory, max_to_keep=5, keep_checkpoint_every_n_hours=None,
                 checkpoint_name="ckpt", step_counter=None, checkpoint_interval=None,
                 init_fn=None):
        """
        Deletes old checkpoints.

        Parameters:
        - checkpoint: The tf.train.Checkpoint instance to save and manage checkpoints for
        - directory: The path to a directory in which to write checkpoints
        - max_to_keep: An integer, the number of checkpoints to keep
        - keep_checkpoint_every_n_hours: Upon removal, keep checkpoints every N hours
        - checkpoint_name: Custom name for the checkpoint file
        - step_counter: A tf.Variable instance for checking the current step counter value
        - checkpoint_interval: An integer, indicates that keep_checkpoint_every_n_hours should be based on checkpoints saved every checkpoint_interval steps
        - init_fn: Callable. Function executed the first time a checkpoint is saved
        """

    def save(self, checkpoint_number=None, check_interval=True):
        """
        Creates a new checkpoint and manages deletion of old checkpoints.

        Parameters:
        - checkpoint_number: An optional integer, or an integer-dtype Variable or Tensor, used to number the checkpoint
        - check_interval: An optional boolean. The default behaviour is that checkpoint_interval is ignored when checkpoint_number is provided

        Returns:
        The path to the new checkpoint. It is also recorded in the checkpoints and latest_checkpoint properties
        """
```

### Checkpoint Utilities

Utility functions for working with checkpoints.

```python { .api }
def list_variables(checkpoint_dir):
    """
    Returns list of all variables in the checkpoint.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint

    Returns:
    List of tuples (name, shape) for all variables in the checkpoint
    """

def load_checkpoint(checkpoint_dir):
    """
    Returns CheckpointReader for checkpoint found in checkpoint_dir.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint

    Returns:
    CheckpointReader instance
    """

def load_variable(checkpoint_dir, name):
    """
    Returns the tensor value of the given variable in the checkpoint.

    Parameters:
    - checkpoint_dir: Directory with checkpoint file or path to checkpoint
    - name: Name of the variable to return

    Returns:
    A numpy ndarray with a copy of the value of this variable
    """

def latest_checkpoint(checkpoint_dir, latest_filename=None):
    """
    Finds the filename of latest saved checkpoint file.

    Parameters:
    - checkpoint_dir: Directory where the variables were saved
    - latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames

    Returns:
    The full path to the latest checkpoint or None if no checkpoint was found
    """
```

### SavedModel Utilities

Additional utilities for working with SavedModel format.

```python { .api }
class SaveOptions:
    """
    Options for saving to SavedModel.

    Parameters:
    - namespace_whitelist: List of strings containing op namespaces to whitelist when saving a model
    - save_debug_info: Boolean indicating whether debug information is saved
    - function_aliases: Optional dictionary of string -> string of function aliases
    - experimental_io_device: string. Applies in a distributed setting
    - experimental_variable_policy: The policy to apply to variables when saving
    """

class LoadOptions:
    """
    Options for loading a SavedModel.

    Parameters:
    - allow_partial_checkpoint: Boolean. Defaults to False. When enabled, allows the SavedModel checkpoint to be missing variables
    - experimental_io_device: string. Loads SavedModel and variables on the specified device
    - experimental_skip_checkpoint: boolean. If True, the checkpoint will not be loaded, and the SavedModel will be loaded with randomly initialized variable values
    """

class Asset:
    """
    Represents a file asset to copy into the SavedModel.

    Parameters:
    - path: A path, or a 0-D tf.string Tensor with path to the asset
    """
```

## Usage Examples

```python
import tensorflow as tf
import os

# Create a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Save entire model to SavedModel format
tf.saved_model.save(model, 'my_saved_model')

# Load the saved model
loaded_model = tf.saved_model.load('my_saved_model')

# For Keras models, use keras save/load for full functionality
model.save('my_keras_model.h5')
loaded_keras_model = tf.keras.models.load_model('my_keras_model.h5')

# Checkpoint example
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Create checkpoint object
checkpoint = tf.train.Checkpoint(optimizer=tf.keras.optimizers.Adam(),
                                 model=model)

# Save checkpoint
checkpoint.save(file_prefix=checkpoint_prefix)

# Restore from checkpoint
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# Using CheckpointManager for automatic cleanup
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3
)

# Save with automatic cleanup
save_path = manager.save()
print(f"Saved checkpoint: {save_path}")

# Training loop with checkpointing
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(checkpoint, './checkpoints', max_to_keep=3)

# Restore if checkpoint exists
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f"Restored from {manager.latest_checkpoint}")
else:
    print("Initializing from scratch.")

# Training step function
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = tf.keras.losses.binary_crossentropy(y, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

# Training loop
for epoch in range(10):
    # Training code here...
    # x_batch, y_batch = get_batch()
    # loss = train_step(x_batch, y_batch)

    # Save checkpoint every few epochs
    if epoch % 2 == 0:
        save_path = manager.save()
        print(f"Saved checkpoint for epoch {epoch}: {save_path}")

# Inspect checkpoint contents
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint_path:
    variables = tf.train.list_variables(checkpoint_path)
    for name, shape in variables:
        print(f"Variable: {name}, Shape: {shape}")

    # Load specific variable
    specific_var = tf.train.load_variable(checkpoint_path, 'model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE')
    print(f"Loaded variable shape: {specific_var.shape}")

# Check if directory contains SavedModel
if tf.saved_model.contains_saved_model('my_saved_model'):
    print("Directory contains a valid SavedModel")

# Advanced SavedModel with custom signatures
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def inference_func(x):
    return model(x)

# Save with custom signature
tf.saved_model.save(
    model,
    'model_with_signature',
    signatures={'serving_default': inference_func}
)
```