0
# Configuration and Persistence
1
2
Configuration objects and data structures for TensorFlow sessions, saved models, and model metadata management. These types are essential for configuring TensorFlow execution and saving/loading trained models.
3
4
## Capabilities
5
6
### ConfigProto
7
8
Session configuration parameters controlling TensorFlow execution behavior.
9
10
```java { .api }
11
/**
12
* Session configuration parameters
13
*/
14
class ConfigProto {
15
/** Get device type to count mapping */
16
Map<String, Integer> getDeviceCountMap();
17
18
/** Get number of threads for intra-op parallelism */
19
int getIntraOpParallelismThreads();
20
21
/** Get number of threads for inter-op parallelism */
22
int getInterOpParallelismThreads();
23
24
/** Check if soft placement is enabled */
25
boolean getAllowSoftPlacement();
26
27
/** Check if device placement logging is enabled */
28
boolean getLogDevicePlacement();
29
30
/** Get GPU configuration options */
31
GPUOptions getGpuOptions();
32
33
/** Check if this session uses its own set of threads instead of the global thread pool */
33
boolean getUsePerSessionThreads();
35
36
/** Get graph optimization options */
37
GraphOptions getGraphOptions();
38
39
/** Get operation timeout in milliseconds */
40
long getOperationTimeoutInMs();
41
42
/** Get RPC options for distributed execution */
43
RPCOptions getRpcOptions();
44
45
/** Get cluster configuration */
46
ClusterDef getClusterDef();
47
48
/** Check if session state (e.g. variables) is isolated from other sessions */
49
boolean getIsolateSessionState();
50
51
/** Create a new builder for constructing ConfigProto */
52
static Builder newBuilder();
53
54
/** Builder for constructing ConfigProto instances */
55
static class Builder {
56
Builder putDeviceCount(String key, int value);
57
Builder setIntraOpParallelismThreads(int threads);
58
Builder setInterOpParallelismThreads(int threads);
59
Builder setAllowSoftPlacement(boolean allow);
60
Builder setLogDevicePlacement(boolean log);
61
Builder setGpuOptions(GPUOptions options);
62
Builder setGraphOptions(GraphOptions options);
63
Builder setOperationTimeoutInMs(long timeout);
64
Builder setRpcOptions(RPCOptions options);
65
Builder setClusterDef(ClusterDef cluster);
66
ConfigProto build();
67
}
68
}
69
```
70
71
**Usage Examples:**
72
73
```java
74
import org.tensorflow.framework.*;
75
76
// Create a basic session configuration
77
ConfigProto config = ConfigProto.newBuilder()
78
.setAllowSoftPlacement(true)
79
.setLogDevicePlacement(false)
80
.setIntraOpParallelismThreads(4)
81
.setInterOpParallelismThreads(2)
82
.putDeviceCount("GPU", 1)
83
.putDeviceCount("CPU", 1)
84
.build();
85
86
// Configure GPU options
87
ConfigProto gpuConfig = ConfigProto.newBuilder()
88
.setGpuOptions(GPUOptions.newBuilder()
89
.setAllowGrowth(true)
90
.setPerProcessGpuMemoryFraction(0.8)
91
.setVisibleDeviceList("0") // Use only GPU 0
92
.build())
93
.build();
94
95
// Configure graph optimization
96
ConfigProto optimizedConfig = ConfigProto.newBuilder()
97
.setGraphOptions(GraphOptions.newBuilder()
98
.setEnableRecvScheduling(true)
99
.setOptimizerOptions(OptimizerOptions.newBuilder()
100
.setDoCommonSubexpressionElimination(true)
101
.setDoConstantFolding(true)
102
.setDoFunctionInlining(true)
103
.build())
104
.build())
105
.build();
106
```
107
108
### GPUOptions
109
110
GPU-specific configuration options.
111
112
```java { .api }
113
/**
114
* GPU configuration options
115
*/
116
class GPUOptions {
117
/** Get GPU memory fraction (0.0 to 1.0) */
118
double getPerProcessGpuMemoryFraction();
119
120
/** Check if GPU memory growth is allowed */
121
boolean getAllowGrowth();
122
123
/** Get GPU memory allocator type */
124
String getAllocatorType();
125
126
/** Get deferred deletion bytes */
127
long getDeferredDeletionBytes();
128
129
/** Get comma-separated list of visible GPU device IDs */
130
String getVisibleDeviceList();
131
132
/** Get polling inactive delay in milliseconds */
133
int getPollingInactiveDelayMsecs();
134
135
/** Check if force GPU compatible is enabled */
136
boolean getForceGpuCompatible();
137
138
/** Get experimental options */
139
Experimental getExperimental();
140
141
static class Experimental {
142
List<String> getVirtualDevicesList();
143
boolean getUseUnifiedMemory();
144
int getNumDevToDevCopyStreams();
145
String getCollectiveRingOrder();
146
boolean getTimestampedAllocator();
147
int getKernelTrackerMaxInterval();
148
int getKernelTrackerMaxBytes();
149
int getKernelTrackerMaxPending();
150
double getInternalFragmentationFraction();
151
boolean getUseMultiDeviceIterator();
152
}
153
}
154
```
155
156
### GraphOptions
157
158
Graph optimization configuration options.
159
160
```java { .api }
161
/**
162
* Graph optimization configuration options
163
*/
164
class GraphOptions {
165
/** Check if receiver scheduling is enabled */
166
boolean getEnableRecvScheduling();
167
168
/** Get optimizer options */
169
OptimizerOptions getOptimizerOptions();
170
171
/** Get build cost model */
172
long getBuildCostModel();
173
174
/** Get inference time */
175
long getInferenceTime();
176
177
/** Get placement period */
178
int getPlacementPeriod();
179
180
/** Check if timeline tracing is enabled */
181
boolean getEnableTimeline();
182
183
/** Get rewrite options */
184
RewriterConfig getRewriteOptions();
185
186
static Builder newBuilder();
187
188
static class Builder {
189
Builder setEnableRecvScheduling(boolean enable);
190
Builder setOptimizerOptions(OptimizerOptions options);
191
Builder setBuildCostModel(long model);
192
GraphOptions build();
193
}
194
}
195
```
196
197
### OptimizerOptions
198
199
Optimization configuration for graphs.
200
201
```java { .api }
202
/**
203
* Optimization configuration for graphs
204
*/
205
class OptimizerOptions {
206
/** Check if common subexpression elimination is enabled */
207
boolean getDoCommonSubexpressionElimination();
208
209
/** Check if constant folding is enabled */
210
boolean getDoConstantFolding();
211
212
/** Check if function inlining is enabled */
213
boolean getDoFunctionInlining();
214
215
/** Get optimization level */
216
Level getOptLevel();
217
218
/** Get global JIT level */
219
GlobalJitLevel getGlobalJitLevel();
220
221
static Builder newBuilder();
222
223
static class Builder {
224
Builder setDoCommonSubexpressionElimination(boolean enable);
225
Builder setDoConstantFolding(boolean enable);
226
Builder setDoFunctionInlining(boolean enable);
227
Builder setOptLevel(Level level);
228
OptimizerOptions build();
229
}
230
231
enum Level {
232
L1, L0
233
}
234
235
enum GlobalJitLevel {
236
DEFAULT,
237
OFF,
238
ON_1,
239
ON_2
240
}
241
}
242
```
243
244
### RPCOptions
245
246
RPC configuration for distributed execution.
247
248
```java { .api }
249
/**
250
* RPC configuration for distributed execution
251
*/
252
class RPCOptions {
253
/** Check if RPC is used to contact an in-process master */
254
boolean getUseRpcForInprocessMaster();
255
256
/** Get compression algorithm */
257
String getCompressionAlgorithm();
258
259
/** Get compression level */
260
int getCompressionLevel();
261
262
/** Check if streaming RPC is enabled */
263
boolean getStreamingRpc();
264
265
static Builder newBuilder();
266
267
static class Builder {
268
Builder setUseRpcForInprocessMaster(boolean use);
269
Builder setCompressionAlgorithm(String algorithm);
270
Builder setCompressionLevel(int level);
271
RPCOptions build();
272
}
273
}
274
```
275
276
### MetaGraphDef
277
278
Container for serializing complete TensorFlow models with metadata.
279
280
```java { .api }
281
/**
282
* Container for serializing complete models with metadata
283
*/
284
class MetaGraphDef {
285
/** Get metadata about the graph */
286
MetaInfoDef getMetaInfoDef();
287
288
/** Get the computation graph */
289
GraphDef getGraphDef();
290
291
/** Get saver configuration */
292
SaverDef getSaverDef();
293
294
/** Get collections mapping */
295
Map<String, CollectionDef> getCollectionDefMap();
296
297
/** Get input/output signatures */
298
Map<String, SignatureDef> getSignatureDefMap();
299
300
/** Get asset file definitions */
301
List<AssetFileDef> getAssetFileDefList();
302
303
/** Get object graph definition */
304
SavedObjectGraph getObjectGraphDef();
305
306
/** Create a new builder for constructing MetaGraphDef */
307
static Builder newBuilder();
308
309
/** Builder for constructing MetaGraphDef instances */
310
static class Builder {
311
Builder setMetaInfoDef(MetaInfoDef metaInfo);
312
Builder setGraphDef(GraphDef graphDef);
313
Builder setSaverDef(SaverDef saverDef);
314
Builder putCollectionDef(String key, CollectionDef value);
315
Builder putSignatureDef(String key, SignatureDef value);
316
Builder addAssetFileDef(AssetFileDef assetFile);
317
Builder setObjectGraphDef(SavedObjectGraph objectGraph);
318
MetaGraphDef build();
319
}
320
}
321
```
322
323
**Usage Examples:**
324
325
```java
326
import org.tensorflow.framework.*;
327
328
// First create a simple graph for demonstration
329
GraphDef trainedGraph = GraphDef.newBuilder()
330
.addNode(NodeDef.newBuilder()
331
.setName("input")
332
.setOp("Placeholder")
333
.putAttr("dtype", AttrValue.newBuilder().setType(DataType.DT_FLOAT).build()))
334
.addNode(NodeDef.newBuilder()
335
.setName("output")
336
.setOp("Identity")
337
.addInput("input"))
338
.build();
339
340
// Create a MetaGraphDef for the trained model
341
MetaGraphDef metaGraph = MetaGraphDef.newBuilder()
342
.setMetaInfoDef(MetaInfoDef.newBuilder()
343
.setMetaGraphVersion("1.15.0")
344
.addTags("serve")
345
.setTensorflowVersion("1.15.0")
346
.setTensorflowGitVersion("v1.15.0"))
347
.setGraphDef(trainedGraph)
348
.setSaverDef(SaverDef.newBuilder()
349
.setFilenameTensorName("save/Const:0")
350
.setSaveTensorName("save/Identity:0")
351
.setRestoreOpName("save/restore_all")
352
.setMaxToKeep(5)
353
.setKeepCheckpointEveryNHours(10000.0f))
354
.putSignatureDef("serving_default", SignatureDef.newBuilder()
355
.setMethodName("tensorflow/serving/predict")
356
.putInputs("input", TensorInfo.newBuilder()
357
.setName("input:0")
358
.setDtype(DataType.DT_FLOAT)
359
.setTensorShape(TensorShapeProto.newBuilder()
360
.addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))
361
.addDim(TensorShapeProto.Dim.newBuilder().setSize(784)))
362
.build())
363
.putOutputs("output", TensorInfo.newBuilder()
364
.setName("output:0")
365
.setDtype(DataType.DT_FLOAT)
366
.setTensorShape(TensorShapeProto.newBuilder()
367
.addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))
368
.addDim(TensorShapeProto.Dim.newBuilder().setSize(10)))
369
.build())
370
.build())
371
.build();
372
```
373
374
### SavedModel
375
376
High-level serialization format for TensorFlow models.
377
378
```java { .api }
379
/**
380
* High-level serialization format for TensorFlow models
381
*/
382
class SavedModel {
383
/** Get schema version */
384
long getSavedModelSchemaVersion();
385
386
/** Get MetaGraphs */
387
List<MetaGraphDef> getMetaGraphsList();
388
389
/** Create a new builder for constructing SavedModel */
390
static Builder newBuilder();
391
392
/** Builder for constructing SavedModel instances */
393
static class Builder {
394
Builder setSavedModelSchemaVersion(long version);
395
Builder addMetaGraphs(MetaGraphDef metaGraph);
396
Builder addAllMetaGraphs(Iterable<MetaGraphDef> metaGraphs);
397
SavedModel build();
398
}
399
}
400
```
401
402
### SignatureDef
403
404
Defines model input/output signatures for serving.
405
406
```java { .api }
407
/**
408
* Model input/output signature for serving
409
*/
410
class SignatureDef {
411
/** Get signature inputs */
412
Map<String, TensorInfo> getInputsMap();
413
414
/** Get signature outputs */
415
Map<String, TensorInfo> getOutputsMap();
416
417
/** Get method name (e.g., 'tensorflow/serving/predict') */
418
String getMethodName();
419
420
/** Create a new builder for constructing SignatureDef */
421
static Builder newBuilder();
422
423
/** Builder for constructing SignatureDef instances */
424
static class Builder {
425
Builder putInputs(String key, TensorInfo value);
426
Builder putOutputs(String key, TensorInfo value);
427
Builder setMethodName(String methodName);
428
SignatureDef build();
429
}
430
}
431
```
432
433
### TensorInfo
434
435
Information about tensors in signatures.
436
437
```java { .api }
438
/**
439
* Information about tensors in model signatures
440
*/
441
class TensorInfo {
442
/** Get tensor name */
443
String getName();
444
445
/** Get sparse tensor encoding in coordinate (COO) format */
446
CooSparse getCooSparse();
447
448
/** Get composite tensor information */
449
CompositeTensor getCompositeTensor();
450
451
/** Get tensor data type */
452
DataType getDtype();
453
454
/** Get tensor shape */
455
TensorShapeProto getTensorShape();
456
457
/** Get which encoding is used */
458
EncodingCase getEncodingCase();
459
460
/** Create a new builder for constructing TensorInfo */
461
static Builder newBuilder();
462
463
/** Builder for constructing TensorInfo instances */
464
static class Builder {
465
Builder setName(String name);
466
Builder setCooSparse(CooSparse cooSparse);
467
Builder setCompositeTensor(CompositeTensor compositeTensor);
468
Builder setDtype(DataType dtype);
469
Builder setTensorShape(TensorShapeProto shape);
470
TensorInfo build();
471
}
472
473
enum EncodingCase {
474
NAME,
475
COO_SPARSE,
476
COMPOSITE_TENSOR,
477
ENCODING_NOT_SET
478
}
479
}
480
```
481
482
### SaverDef
483
484
Configuration for model checkpointing and restoration.
485
486
```java { .api }
487
/**
488
* Configuration for saving and restoring models
489
*/
490
class SaverDef {
491
/** Get filename tensor name */
492
String getFilenameTensorName();
493
494
/** Get save tensor name */
495
String getSaveTensorName();
496
497
/** Get restore operation name */
498
String getRestoreOpName();
499
500
/** Get maximum number of checkpoints to keep */
501
int getMaxToKeep();
502
503
/** Get sharded save option */
504
boolean getSharded();
505
506
/** Get interval, in hours, at which an additional checkpoint is kept */
507
float getKeepCheckpointEveryNHours();
508
509
/** Get checkpoint version */
510
CheckpointFormatVersion getVersion();
511
512
enum CheckpointFormatVersion {
513
LEGACY,
514
V1,
515
V2
516
}
517
}
518
```
519
520
**Usage Examples:**
521
522
```java
523
import org.tensorflow.framework.*;
524
525
// Create a serving signature for image classification
526
SignatureDef servingSignature = SignatureDef.newBuilder()
527
.setMethodName("tensorflow/serving/predict")
528
.putInputs("image", TensorInfo.newBuilder()
529
.setName("input_image:0")
530
.setDtype(DataType.DT_UINT8)
531
.setTensorShape(TensorShapeProto.newBuilder()
532
.addDim(TensorShapeProto.Dim.newBuilder().setSize(-1)) // batch
533
.addDim(TensorShapeProto.Dim.newBuilder().setSize(224)) // height
534
.addDim(TensorShapeProto.Dim.newBuilder().setSize(224)) // width
535
.addDim(TensorShapeProto.Dim.newBuilder().setSize(3))) // channels
536
.build())
537
.putOutputs("predictions", TensorInfo.newBuilder()
538
.setName("output_predictions:0")
539
.setDtype(DataType.DT_FLOAT)
540
.setTensorShape(TensorShapeProto.newBuilder()
541
.addDim(TensorShapeProto.Dim.newBuilder().setSize(-1)) // batch
542
.addDim(TensorShapeProto.Dim.newBuilder().setSize(1000))) // classes
543
.build())
544
.build();
545
546
// Create saver configuration
547
SaverDef saver = SaverDef.newBuilder()
548
.setFilenameTensorName("save/Const:0")
549
.setSaveTensorName("save/Identity:0")
550
.setRestoreOpName("save/restore_all")
551
.setMaxToKeep(10)
552
.setSharded(true)
553
.setKeepCheckpointEveryNHours(1.0f)
554
.setVersion(SaverDef.CheckpointFormatVersion.V2)
555
.build();
556
```