Generated Java code for TensorFlow protocol buffers, providing type-safe access to TensorFlow's structured data formats including MetaGraphDef, ConfigProto, and other core TensorFlow data structures.
Configuration objects and data structures for TensorFlow sessions, saved models, and model metadata management. These types are essential for configuring TensorFlow execution and saving/loading trained models.
Session configuration parameters controlling TensorFlow execution behavior.
/**
 * Session configuration parameters: thread pools, device placement, GPU,
 * graph-optimization, RPC and cluster settings.
 */
class ConfigProto {
    /** Get the map from device type name (e.g. "CPU", "GPU") to the maximum number of devices of that type to use */
    Map<String, Integer> getDeviceCountMap();
    /** Get number of threads for intra-op parallelism */
    int getIntraOpParallelismThreads();
    /** Get number of threads for inter-op parallelism */
    int getInterOpParallelismThreads();
    /** Check if soft placement is enabled */
    boolean getAllowSoftPlacement();
    /** Check if device placement logging is enabled */
    boolean getLogDevicePlacement();
    /** Get GPU configuration options (memory fraction, memory growth, visible devices, ...) */
    GPUOptions getGpuOptions();
    /** Get graph optimization options */
    GraphOptions getGraphOptions();
    /** Get operation timeout in milliseconds */
    long getOperationTimeoutInMs();
    /** Get RPC options for distributed execution */
    RPCOptions getRpcOptions();
    /** Get cluster configuration */
    ClusterDef getClusterDef();
    /** Check whether session state (e.g. variables) is isolated from other sessions */
    boolean getIsolateSessionState();
    /** Create a new builder for constructing ConfigProto */
    static Builder newBuilder();
    /** Builder for constructing ConfigProto instances */
    static class Builder {
        Builder putDeviceCount(String key, int value);
        Builder setIntraOpParallelismThreads(int threads);
        Builder setInterOpParallelismThreads(int threads);
        Builder setAllowSoftPlacement(boolean allow);
        Builder setLogDevicePlacement(boolean log);
        Builder setGpuOptions(GPUOptions options);
        Builder setGraphOptions(GraphOptions options);
        Builder setOperationTimeoutInMs(long timeout);
        Builder setRpcOptions(RPCOptions options);
        Builder setClusterDef(ClusterDef cluster);
        Builder setIsolateSessionState(boolean isolate);
        ConfigProto build();
    }
}
Usage Examples:
import org.tensorflow.framework.*;
// Create a basic session configuration
ConfigProto config = ConfigProto.newBuilder()
    .setAllowSoftPlacement(true)
    .setLogDevicePlacement(false)
    .setIntraOpParallelismThreads(4)
    .setInterOpParallelismThreads(2)
    .putDeviceCount("GPU", 1)
    .putDeviceCount("CPU", 1)
    .build();
// Configure GPU options
ConfigProto gpuConfig = ConfigProto.newBuilder()
    .setGpuOptions(GPUOptions.newBuilder()
        .setAllowGrowth(true)
        .setPerProcessGpuMemoryFraction(0.8)
        // visible_device_list is a single comma-separated string (e.g. "0,1"),
        // not a repeated field, so the generated setter is setVisibleDeviceList.
        .setVisibleDeviceList("0") // Use only GPU 0
        .build())
    .build();
// Configure graph optimization
ConfigProto optimizedConfig = ConfigProto.newBuilder()
    .setGraphOptions(GraphOptions.newBuilder()
        .setEnableRecvScheduling(true)
        .setOptimizerOptions(OptimizerOptions.newBuilder()
            .setDoCommonSubexpressionElimination(true)
            .setDoConstantFolding(true)
            .setDoFunctionInlining(true)
            .build())
        .build())
    .build();
GPU-specific configuration options.
/**
 * GPU configuration options
 */
class GPUOptions {
    /** Get the fraction of available GPU memory each process may allocate (typically 0.0 to 1.0) */
    double getPerProcessGpuMemoryFraction();
    /** Check if GPU memory growth is allowed (allocate on demand rather than pre-allocating) */
    boolean getAllowGrowth();
    /** Get GPU memory allocator type (empty string selects the default allocator) */
    String getAllocatorType();
    /** Get deferred deletion bytes */
    long getDeferredDeletionBytes();
    /**
     * Get the visible GPU device IDs as one comma-separated string (e.g. "0,1").
     * The proto field is a scalar string, so there is no list accessor.
     */
    String getVisibleDeviceList();
    /** Get polling inactive delay in milliseconds */
    int getPollingInactiveDelayMsecs();
    /** Check if force GPU compatible is enabled */
    boolean getForceGpuCompatible();
    /** Get experimental options */
    Experimental getExperimental();
    /** Create a new builder for constructing GPUOptions */
    static Builder newBuilder();
    /** Builder for constructing GPUOptions instances */
    static class Builder {
        Builder setPerProcessGpuMemoryFraction(double fraction);
        Builder setAllowGrowth(boolean allow);
        Builder setAllocatorType(String type);
        Builder setDeferredDeletionBytes(long bytes);
        /** Set visible GPU IDs as a comma-separated string, e.g. "0" or "0,1" */
        Builder setVisibleDeviceList(String deviceList);
        Builder setPollingInactiveDelayMsecs(int delayMsecs);
        Builder setForceGpuCompatible(boolean force);
        Builder setExperimental(Experimental experimental);
        GPUOptions build();
    }
    /** Experimental GPU options; not covered by API stability guarantees */
    static class Experimental {
        List<VirtualDevices> getVirtualDevicesList();
        boolean getUseUnifiedMemory();
        int getNumDevToDevCopyStreams();
        String getCollectiveRingOrder();
        boolean getTimestampedAllocator();
        int getKernelTrackerMaxInterval();
        int getKernelTrackerMaxBytes();
        int getKernelTrackerMaxPending();
        double getInternalFragmentationFraction();
        /** NOTE(review): no matching field found in config.proto GPUOptions.Experimental — confirm this accessor exists */
        boolean getUseMultiDeviceIterator();
        /** Virtual device configuration: per-virtual-device memory limits for one physical GPU */
        static class VirtualDevices {
            List<Float> getMemoryLimitMbList();
        }
    }
}
Graph optimization configuration options.
/**
 * Graph optimization configuration options
 */
class GraphOptions {
    /** Check if receiver scheduling is enabled */
    boolean getEnableRecvScheduling();
    /** Get optimizer options */
    OptimizerOptions getOptimizerOptions();
    /** Get the number of steps to run before returning a cost model (0 means no cost model) */
    long getBuildCostModel();
    /** Get the number of steps to skip before collecting statistics for the cost model */
    long getBuildCostModelAfter();
    /** Check if nodes are annotated with statically inferred output shapes */
    boolean getInferShapes();
    /** Check if only the subgraphs that are run (rather than the whole graph) are placed */
    boolean getPlacePrunedGraph();
    /** Get the timeline step: if greater than 0, a timeline is recorded every this many steps */
    int getTimelineStep();
    /** Get rewrite options */
    RewriterConfig getRewriteOptions();
    /** Create a new builder for constructing GraphOptions */
    static Builder newBuilder();
    /** Builder for constructing GraphOptions instances */
    static class Builder {
        Builder setEnableRecvScheduling(boolean enable);
        Builder setOptimizerOptions(OptimizerOptions options);
        Builder setBuildCostModel(long model);
        Builder setInferShapes(boolean infer);
        Builder setTimelineStep(int step);
        Builder setRewriteOptions(RewriterConfig options);
        GraphOptions build();
    }
}
Optimization configuration for graphs.
/**
 * Optimization configuration for graphs
 */
class OptimizerOptions {
    /** Check if common subexpression elimination is enabled */
    boolean getDoCommonSubexpressionElimination();
    /** Check if constant folding is enabled */
    boolean getDoConstantFolding();
    /** Check if function inlining is enabled */
    boolean getDoFunctionInlining();
    /** Get optimization level */
    Level getOptLevel();
    /** Get global XLA JIT compilation level */
    GlobalJitLevel getGlobalJitLevel();
    /** Create a new builder for constructing OptimizerOptions */
    static Builder newBuilder();
    /** Builder for constructing OptimizerOptions instances */
    static class Builder {
        Builder setDoCommonSubexpressionElimination(boolean enable);
        Builder setDoConstantFolding(boolean enable);
        Builder setDoFunctionInlining(boolean enable);
        Builder setOptLevel(Level level);
        Builder setGlobalJitLevel(GlobalJitLevel level);
        OptimizerOptions build();
    }
    /** Graph optimization level: L1 is the default level, L0 disables optimizations */
    enum Level {
        L1, L0
    }
    /** Global JIT level: OFF disables compilation; ON_1 and ON_2 enable it, ON_2 more aggressively */
    enum GlobalJitLevel {
        DEFAULT,
        OFF,
        ON_1,
        ON_2
    }
}
RPC configuration for distributed execution.
/**
 * RPC configuration for distributed execution
 */
class RPCOptions {
    /** Check if RPC is always used to contact the session target, even when the master is in-process */
    boolean getUseRpcForInprocessMaster();
    /** Get compression algorithm (e.g. "deflate" or "gzip") */
    String getCompressionAlgorithm();
    /** Get compression level */
    int getCompressionLevel();
    /**
     * Check if streaming RPC is enabled.
     * NOTE(review): no streaming_rpc field appears in config.proto RPCOptions — confirm this accessor exists.
     */
    boolean getStreamingRpc();
    /** Create a new builder for constructing RPCOptions */
    static Builder newBuilder();
    /** Builder for constructing RPCOptions instances */
    static class Builder {
        Builder setUseRpcForInprocessMaster(boolean use);
        Builder setCompressionAlgorithm(String algorithm);
        Builder setCompressionLevel(int level);
        RPCOptions build();
    }
}
Container for serializing complete TensorFlow models with metadata.
/**
 * Container for serializing complete models with metadata
 */
class MetaGraphDef {
    // ---- Core graph content ----

    /** Returns descriptive metadata attached to this graph */
    MetaInfoDef getMetaInfoDef();
    /** Returns the computation graph itself */
    GraphDef getGraphDef();
    /** Returns the saver (checkpointing) configuration */
    SaverDef getSaverDef();

    // ---- Named and auxiliary entries ----

    /** Returns named collections, keyed by collection name */
    Map<String, CollectionDef> getCollectionDefMap();
    /** Returns input/output signatures keyed by signature name (e.g. "serving_default") */
    Map<String, SignatureDef> getSignatureDefMap();
    /** Returns the asset file definitions */
    List<AssetFileDef> getAssetFileDefList();
    /** Returns the saved object graph definition */
    SavedObjectGraph getObjectGraphDef();

    /** Starts a builder for a new MetaGraphDef */
    static Builder newBuilder();

    /** Mutable builder that produces MetaGraphDef instances */
    static class Builder {
        Builder setMetaInfoDef(MetaInfoDef metaInfo);
        Builder setGraphDef(GraphDef graphDef);
        Builder setSaverDef(SaverDef saverDef);
        Builder putCollectionDef(String key, CollectionDef value);
        Builder putSignatureDef(String key, SignatureDef value);
        Builder addAssetFileDef(AssetFileDef assetFile);
        Builder setObjectGraphDef(SavedObjectGraph objectGraph);
        MetaGraphDef build();
    }
}
Usage Examples:
import org.tensorflow.framework.*;
// Two-node stand-in for a trained graph: a float placeholder feeding an Identity op
NodeDef.Builder inputNode = NodeDef.newBuilder()
    .setName("input")
    .setOp("Placeholder")
    .putAttr("dtype", AttrValue.newBuilder().setType(DataType.DT_FLOAT).build());
NodeDef.Builder outputNode = NodeDef.newBuilder()
    .setName("output")
    .setOp("Identity")
    .addInput("input");
GraphDef trainedGraph = GraphDef.newBuilder()
    .addNode(inputNode)
    .addNode(outputNode)
    .build();
// Tensor descriptions for the serving signature; size -1 leaves the batch dimension open
TensorInfo inputInfo = TensorInfo.newBuilder()
    .setName("input:0")
    .setDtype(DataType.DT_FLOAT)
    .setTensorShape(TensorShapeProto.newBuilder()
        .addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))
        .addDim(TensorShapeProto.Dim.newBuilder().setSize(784)))
    .build();
TensorInfo outputInfo = TensorInfo.newBuilder()
    .setName("output:0")
    .setDtype(DataType.DT_FLOAT)
    .setTensorShape(TensorShapeProto.newBuilder()
        .addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))
        .addDim(TensorShapeProto.Dim.newBuilder().setSize(10)))
    .build();
SignatureDef servingDefault = SignatureDef.newBuilder()
    .setMethodName("tensorflow/serving/predict")
    .putInputs("input", inputInfo)
    .putOutputs("output", outputInfo)
    .build();
// Assemble the MetaGraphDef for the trained model
MetaGraphDef metaGraph = MetaGraphDef.newBuilder()
    .setMetaInfoDef(MetaInfoDef.newBuilder()
        .setMetaGraphVersion("1.15.0")
        .addTags("serve")
        .setTensorflowVersion("1.15.0")
        .setTensorflowGitVersion("v1.15.0"))
    .setGraphDef(trainedGraph)
    .setSaverDef(SaverDef.newBuilder()
        .setFilenameTensorName("save/Const:0")
        .setSaveTensorName("save/Identity:0")
        .setRestoreOpName("save/restore_all")
        .setMaxToKeep(5)
        .setKeepCheckpointEveryNHours(10000.0f))
    .putSignatureDef("serving_default", servingDefault)
    .build();
High-level serialization format for TensorFlow models.
/**
 * High-level serialization format for TensorFlow models
 */
class SavedModel {
    /** Returns the SavedModel schema version */
    long getSavedModelSchemaVersion();
    /** Returns every MetaGraphDef contained in this SavedModel */
    List<MetaGraphDef> getMetaGraphsList();

    /** Starts a builder for a new SavedModel */
    static Builder newBuilder();

    /** Mutable builder that produces SavedModel instances */
    static class Builder {
        /** Sets the schema version */
        Builder setSavedModelSchemaVersion(long version);
        /** Appends one MetaGraphDef */
        Builder addMetaGraphs(MetaGraphDef metaGraph);
        /** Appends every MetaGraphDef produced by the given iterable */
        Builder addAllMetaGraphs(Iterable<MetaGraphDef> metaGraphs);
        /** Produces the immutable SavedModel */
        SavedModel build();
    }
}
Defines model input/output signatures for serving.
/**
 * Model input/output signature for serving
 */
class SignatureDef {
    /** Returns the signature's input tensors, keyed by logical input name */
    Map<String, TensorInfo> getInputsMap();
    /** Returns the signature's output tensors, keyed by logical output name */
    Map<String, TensorInfo> getOutputsMap();
    /** Returns the method name, e.g. 'tensorflow/serving/predict' */
    String getMethodName();

    /** Starts a builder for a new SignatureDef */
    static Builder newBuilder();

    /** Mutable builder that produces SignatureDef instances */
    static class Builder {
        Builder putInputs(String key, TensorInfo value);
        Builder putOutputs(String key, TensorInfo value);
        Builder setMethodName(String methodName);
        SignatureDef build();
    }
}
Information about tensors in signatures.
/**
 * Information about tensors in model signatures.
 * Exactly one encoding (name, COO sparse, or composite tensor) is set; see getEncodingCase().
 */
class TensorInfo {
    /** Get tensor name (set when the encoding is NAME) */
    String getName();
    /** Get the COO-format sparse tensor encoding (set when the encoding is COO_SPARSE) */
    CooSparse getCooSparse();
    /** Get composite tensor information (set when the encoding is COMPOSITE_TENSOR) */
    CompositeTensor getCompositeTensor();
    /** Get tensor data type */
    DataType getDtype();
    /** Get tensor shape */
    TensorShapeProto getTensorShape();
    /** Get which encoding is used */
    EncodingCase getEncodingCase();
    /** Create a new builder for constructing TensorInfo */
    static Builder newBuilder();
    /** Builder for constructing TensorInfo instances */
    static class Builder {
        Builder setName(String name);
        Builder setCooSparse(CooSparse cooSparse);
        Builder setCompositeTensor(CompositeTensor compositeTensor);
        Builder setDtype(DataType dtype);
        Builder setTensorShape(TensorShapeProto shape);
        TensorInfo build();
    }
    /** Which member of the encoding oneof is populated */
    enum EncodingCase {
        NAME,
        COO_SPARSE,
        COMPOSITE_TENSOR,
        ENCODING_NOT_SET
    }
}
Configuration for model checkpointing and restoration.
/**
 * Configuration for saving and restoring models
 */
class SaverDef {
    /** Get the name of the tensor that holds the checkpoint filename */
    String getFilenameTensorName();
    /** Get save tensor name */
    String getSaveTensorName();
    /** Get restore operation name */
    String getRestoreOpName();
    /** Get maximum number of recent checkpoints to keep */
    int getMaxToKeep();
    /** Check if save files are sharded */
    boolean getSharded();
    /** Get the interval, in hours, at which an extra checkpoint is kept in addition to the most recent ones */
    float getKeepCheckpointEveryNHours();
    /** Get checkpoint format version */
    CheckpointFormatVersion getVersion();
    /** Create a new builder for constructing SaverDef */
    static Builder newBuilder();
    /** Builder for constructing SaverDef instances */
    static class Builder {
        Builder setFilenameTensorName(String name);
        Builder setSaveTensorName(String name);
        Builder setRestoreOpName(String name);
        Builder setMaxToKeep(int maxToKeep);
        Builder setSharded(boolean sharded);
        Builder setKeepCheckpointEveryNHours(float hours);
        Builder setVersion(CheckpointFormatVersion version);
        SaverDef build();
    }
    /** On-disk checkpoint format; V2 is the current format */
    enum CheckpointFormatVersion {
        LEGACY,
        V1,
        V2
    }
}
Usage Examples:
import org.tensorflow.framework.*;
// Input shape for an image classifier: [batch, height, width, channels]
TensorShapeProto imageShape = TensorShapeProto.newBuilder()
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))   // batch
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(224))  // height
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(224))  // width
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(3))    // channels
    .build();
// Output shape: [batch, classes]
TensorShapeProto predictionShape = TensorShapeProto.newBuilder()
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(-1))   // batch
    .addDim(TensorShapeProto.Dim.newBuilder().setSize(1000)) // classes
    .build();
// Create a serving signature for image classification
SignatureDef servingSignature = SignatureDef.newBuilder()
    .setMethodName("tensorflow/serving/predict")
    .putInputs("image", TensorInfo.newBuilder()
        .setName("input_image:0")
        .setDtype(DataType.DT_UINT8)
        .setTensorShape(imageShape)
        .build())
    .putOutputs("predictions", TensorInfo.newBuilder()
        .setName("output_predictions:0")
        .setDtype(DataType.DT_FLOAT)
        .setTensorShape(predictionShape)
        .build())
    .build();
// Saver configuration: up to 10 recent sharded V2 checkpoints, plus one kept every hour
SaverDef saver = SaverDef.newBuilder()
    .setFilenameTensorName("save/Const:0")
    .setSaveTensorName("save/Identity:0")
    .setRestoreOpName("save/restore_all")
    .setMaxToKeep(10)
    .setSharded(true)
    .setKeepCheckpointEveryNHours(1.0f)
    .setVersion(SaverDef.CheckpointFormatVersion.V2)
    .build();
Install with Tessl CLI
npx tessl i tessl/maven-org-tensorflow--proto