A model zoo for Deeplearning4j that provides access to pre-trained neural network models.
—
Tools for programmatically selecting, instantiating, and working with multiple zoo models, including helper classes for building custom architectures.
Utility class for selecting and instantiating multiple zoo models based on type. Provides various overloaded methods for different configuration needs.
/**
* Helper class for selecting multiple models from the zoo.
*/
class ModelSelector {
/**
* Select models by type with default configuration.
* <p>
* The aggregate values ZooType.ALL, ZooType.CNN and ZooType.RNN stand for a
* group of models rather than a single architecture (see "Model Type
* Hierarchies" in this document), which is why the result is a map.
* @param zooType Type of models to select
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType zooType);
/**
* Select models by type with custom label count
* @param zooType Type of models to select
* @param numLabels Number of output classes
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels);
/**
* Select models by type with workspace mode
* @param zooType Type of models to select
* @param numLabels Number of output classes
* @param workspaceMode Memory workspace configuration
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, WorkspaceMode workspaceMode);
/**
* Select models by type with training parameters
* @param zooType Type of models to select
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations);
/**
* Select models by type with full parameter control.
* <p>
* This is the overload used by the "custom configuration" usage example:
* {@code select(ZooType.CNN, 10, 42, 100, WorkspaceMode.SINGLE)}.
* Note the argument order: numLabels, then seed, then iterations.
* @param zooType Type of models to select
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations, WorkspaceMode workspaceMode);
/**
* Select specific model types with workspace mode
* @param workspaceMode Memory workspace configuration
* @param zooTypes Specific model types to select
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(WorkspaceMode workspaceMode, ZooType... zooTypes);
/**
* Select specific model types with default configuration.
* <p>
* Varargs overload: pass the wanted types as a comma-separated list, for
* example {@code select(ZooType.ALEXNET, ZooType.VGG16, ZooType.RESNET50)}.
* @param zooTypes Specific model types to select
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(ZooType... zooTypes);
/**
* Select specific model types with full parameter control
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
* @param zooTypes Specific model types to select
* @return Map of ZooType to ZooModel instances
*/
static Map<ZooType, ZooModel> select(int numLabels, int seed, int iterations, WorkspaceMode workspaceMode, ZooType... zooTypes);
}
Usage Examples:
// Select all CNN models with default settings
Map<ZooType, ZooModel> cnnModels = ModelSelector.select(ZooType.CNN);
// Returns: AlexNet, VGG16, VGG19, ResNet50, GoogLeNet, LeNet, SimpleCNN
// Select all models (CNN + RNN)
Map<ZooType, ZooModel> allModels = ModelSelector.select(ZooType.ALL);
// Select specific models
Map<ZooType, ZooModel> specificModels = ModelSelector.select(
ZooType.ALEXNET,
ZooType.VGG16,
ZooType.RESNET50
);
// Select with custom configuration
Map<ZooType, ZooModel> customModels = ModelSelector.select(
ZooType.CNN,
10, // 10 classes
42, // seed
100, // iterations
WorkspaceMode.SINGLE
);
// Iterate through selected models
for (Map.Entry<ZooType, ZooModel> entry : cnnModels.entrySet()) {
ZooType type = entry.getKey();
ZooModel model = entry.getValue();
System.out.println("Model: " + type);
Model initializedModel = model.init();
ModelMetaData metadata = model.metaData();
System.out.println("Input shape: " + Arrays.deepToString(metadata.getInputShape()));
}
Classification system for different model types and categories.
/**
* Enumerator for choosing different models, and different types of models.
*/
enum ZooType {
/** All available models */
ALL,
/** All CNN models */
CNN,
/** Simple CNN architecture */
SIMPLECNN,
/** AlexNet architecture */
ALEXNET,
/** LeNet architecture */
LENET,
/** GoogLeNet/Inception architecture */
GOOGLENET,
/** VGG16 architecture */
VGG16,
/** VGG19 architecture */
VGG19,
/** ResNet50 architecture */
RESNET50,
/** InceptionResNetV1 architecture */
INCEPTIONRESNETV1,
/** FaceNet NN4 Small2 architecture */
FACENETNN4SMALL2,
/** All RNN models */
RNN,
/** Text generation LSTM */
TEXTGENLSTM
}
Model Type Hierarchies:
// CNN models include:
ModelSelector.select(ZooType.CNN); // Returns all CNN architectures
// - SIMPLECNN, ALEXNET, LENET, GOOGLENET, VGG16, VGG19, RESNET50
// RNN models include:
ModelSelector.select(ZooType.RNN); // Returns all RNN architectures
// - TEXTGENLSTM
// ALL includes both CNN and RNN:
ModelSelector.select(ZooType.ALL); // Returns all available models
Types of pre-trained model weights available for supported models.
/**
* Enumerator for choosing different pre-trained weight types.
*/
enum PretrainedType {
/** ImageNet dataset pre-trained weights (1000 classes) */
IMAGENET,
/** MNIST dataset pre-trained weights (10 digit classes) */
MNIST,
/** CIFAR-10 dataset pre-trained weights (10 object classes) */
CIFAR10,
/** VGGFace dataset pre-trained weights (face recognition) */
VGGFACE
}
Pre-trained Weight Availability:
VGG16 vgg16 = new VGG16(1000, 42, 1);
// Check which pre-trained weights are available
boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET); // true
boolean hasCIFAR10 = vgg16.pretrainedAvailable(PretrainedType.CIFAR10); // true
boolean hasVGGFace = vgg16.pretrainedAvailable(PretrainedType.VGGFACE); // true
boolean hasMNIST = vgg16.pretrainedAvailable(PretrainedType.MNIST); // false
// Load specific pre-trained weights
Model imageNetModel = vgg16.initPretrained(PretrainedType.IMAGENET);
Model cifar10Model = vgg16.initPretrained(PretrainedType.CIFAR10);
Utility class for building Inception-style layers used in FaceNet and other advanced architectures.
/**
* Helper class for building Inception-style modules used in FaceNet models.
* Provides pre-configured layers and graph building utilities.
*/
class FaceNetHelper {
/**
* Returns base module name for inception layers
* @return "inception"
*/
static String getModuleName();
/**
* Returns namespaced module name
* @param layerName Name of the specific layer
* @return Formatted module name
*/
static String getModuleName(String layerName);
/**
* Creates 1x1 convolution layer
* @param in Number of input channels
* @param out Number of output channels
* @param bias Bias initialization value
* @return ConvolutionLayer configured as 1x1 convolution
*/
static ConvolutionLayer conv1x1(int in, int out, double bias);
/**
* Creates 3x3 convolution layer
* @param in Number of input channels
* @param out Number of output channels
* @param bias Bias initialization value
* @return ConvolutionLayer configured as 3x3 convolution
*/
static ConvolutionLayer conv3x3(int in, int out, double bias);
/**
* Creates 5x5 convolution layer
* @param in Number of input channels
* @param out Number of output channels
* @param bias Bias initialization value
* @return ConvolutionLayer configured as 5x5 convolution
*/
static ConvolutionLayer conv5x5(int in, int out, double bias);
/**
* Creates 7x7 convolution layer
* @param in Number of input channels
* @param out Number of output channels
* @param bias Bias initialization value
* @return ConvolutionLayer configured as 7x7 convolution
*/
static ConvolutionLayer conv7x7(int in, int out, double bias);
/**
* Creates average pooling layer
* @param size Pool size (NxN)
* @param stride Stride for pooling
* @return SubsamplingLayer configured for average pooling
*/
static SubsamplingLayer avgPoolNxN(int size, int stride);
/**
* Creates max pooling layer
* @param size Pool size (NxN)
* @param stride Stride for pooling
* @return SubsamplingLayer configured for max pooling
*/
static SubsamplingLayer maxPoolNxN(int size, int stride);
/**
* Creates p-norm pooling layer
* @param pNorm P-norm value
* @param size Pool size (NxN)
* @param stride Stride for pooling
* @return SubsamplingLayer configured for p-norm pooling
*/
static SubsamplingLayer pNormNxN(int pNorm, int size, int stride);
/**
* Creates fully connected (dense) layer
* @param in Number of input units
* @param out Number of output units
* @param dropOut Dropout rate
* @return DenseLayer with specified configuration
*/
static DenseLayer fullyConnected(int in, int out, double dropOut);
/**
* Creates batch normalization layer
* @param in Number of input channels
* @param out Number of output channels
* @return BatchNormalization layer
*/
static BatchNormalization batchNorm(int in, int out);
/**
* Appends complete Inception module to a computation graph with default parameters
* @param graph Existing graph builder
* @param moduleLayerName Name for this inception module
* @param inputSize Number of input channels
* @param kernelSize Array of kernel sizes for different paths
* @param kernelStride Array of strides for different paths
* @param outputSize Array of output sizes for different paths
* @param reduceSize Array of reduction sizes for different paths
* @param poolingType Type of pooling to use
* @param transferFunction Activation function
* @param inputLayer Name of input layer to connect to
* @return Updated GraphBuilder with inception module added
*/
static ComputationGraphConfiguration.GraphBuilder appendGraph(
ComputationGraphConfiguration.GraphBuilder graph,
String moduleLayerName,
int inputSize,
int[] kernelSize,
int[] kernelStride,
int[] outputSize,
int[] reduceSize,
SubsamplingLayer.PoolingType poolingType,
Activation transferFunction,
String inputLayer
);
/**
* Appends complete Inception module to a computation graph with p-norm pooling
* @param graph Existing graph builder
* @param moduleLayerName Name for this inception module
* @param inputSize Number of input channels
* @param kernelSize Array of kernel sizes for different paths
* @param kernelStride Array of strides for different paths
* @param outputSize Array of output sizes for different paths
* @param reduceSize Array of reduction sizes for different paths
* @param poolingType Type of pooling to use
* @param pNorm P-norm value (if using p-norm pooling)
* @param transferFunction Activation function
* @param inputLayer Name of input layer to connect to
* @return Updated GraphBuilder with inception module added
*/
static ComputationGraphConfiguration.GraphBuilder appendGraph(
ComputationGraphConfiguration.GraphBuilder graph,
String moduleLayerName,
int inputSize,
int[] kernelSize,
int[] kernelStride,
int[] outputSize,
int[] reduceSize,
SubsamplingLayer.PoolingType poolingType,
int pNorm,
Activation transferFunction,
String inputLayer
);
/**
* Appends complete Inception module to a computation graph with custom pooling parameters
* @param graph Existing graph builder
* @param moduleLayerName Name for this inception module
* @param inputSize Number of input channels
* @param kernelSize Array of kernel sizes for different paths
* @param kernelStride Array of strides for different paths
* @param outputSize Array of output sizes for different paths
* @param reduceSize Array of reduction sizes for different paths
* @param poolingType Type of pooling to use
* @param poolSize Size of pooling window
* @param poolStride Stride for pooling
* @param transferFunction Activation function
* @param inputLayer Name of input layer to connect to
* @return Updated GraphBuilder with inception module added
*/
static ComputationGraphConfiguration.GraphBuilder appendGraph(
ComputationGraphConfiguration.GraphBuilder graph,
String moduleLayerName,
int inputSize,
int[] kernelSize,
int[] kernelStride,
int[] outputSize,
int[] reduceSize,
SubsamplingLayer.PoolingType poolingType,
int poolSize,
int poolStride,
Activation transferFunction,
String inputLayer
);
/**
* Appends complete Inception module to a computation graph with full parameter control.
* <p>
* This is the overload called in the usage example in this document. That
* example passes 0 for pNorm together with PoolingType.MAX and comments that
* the value is not used for MAX pooling.
* @param graph Existing graph builder
* @param moduleLayerName Name for this inception module
* @param inputSize Number of input channels
* @param kernelSize Array of kernel sizes for different paths
* @param kernelStride Array of strides for different paths
* @param outputSize Array of output sizes for different paths
* @param reduceSize Array of reduction sizes for different paths
* @param poolingType Type of pooling to use
* @param pNorm P-norm value (if using p-norm pooling)
* @param poolSize Size of pooling window
* @param poolStride Stride for pooling
* @param transferFunction Activation function
* @param inputLayer Name of input layer to connect to
* @return Updated GraphBuilder with inception module added
*/
static ComputationGraphConfiguration.GraphBuilder appendGraph(
ComputationGraphConfiguration.GraphBuilder graph,
String moduleLayerName,
int inputSize,
int[] kernelSize,
int[] kernelStride,
int[] outputSize,
int[] reduceSize,
SubsamplingLayer.PoolingType poolingType,
int pNorm,
int poolSize,
int poolStride,
Activation transferFunction,
String inputLayer
);
}
Usage Example:
// Building custom architecture with Inception modules
ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("input");
// Add custom Inception module
graph = FaceNetHelper.appendGraph(
graph,
"inception_1", // module name
64, // input channels
new int[]{3, 5}, // kernel sizes
new int[]{1, 1}, // strides
new int[]{128, 64}, // output sizes
new int[]{32, 16, 8}, // reduction sizes
SubsamplingLayer.PoolingType.MAX,
0, // p-norm (not used for MAX pooling)
3, // pool size
1, // pool stride
Activation.RELU, // activation
"input" // input layer name
);
Helper class for building Inception-ResNet architectures that combine Inception modules with residual connections.
/**
* Helper class for building Inception-ResNet modules that combine residual shortcuts
* with Inception-style networks. Based on the Inception-ResNet paper.
*/
class InceptionResNetHelper {
/**
* Creates layer name with block and iteration naming
* @param blockName Name of the inception block
* @param layerName Name of the specific layer
* @param i Iteration/block number
* @return Formatted layer name
*/
static String nameLayer(String blockName, String layerName, int i);
/**
* Appends Inception-ResNet A blocks to a computation graph
* @param graph Existing graph builder
* @param blockName Name for this inception block
* @param scale Number of blocks to add
* @param activationScale Scaling factor for activations
* @param input Name of input layer to connect to
* @return Updated GraphBuilder with Inception-ResNet A blocks added
*/
static ComputationGraphConfiguration.GraphBuilder inceptionV1ResA(
ComputationGraphConfiguration.GraphBuilder graph,
String blockName,
int scale,
double activationScale,
String input
);
/**
* Appends Inception-ResNet B blocks to a computation graph
* @param graph Existing graph builder
* @param blockName Name for this inception block
* @param scale Number of blocks to add
* @param activationScale Scaling factor for activations
* @param input Name of input layer to connect to
* @return Updated GraphBuilder with Inception-ResNet B blocks added
*/
static ComputationGraphConfiguration.GraphBuilder inceptionV1ResB(
ComputationGraphConfiguration.GraphBuilder graph,
String blockName,
int scale,
double activationScale,
String input
);
/**
* Appends Inception-ResNet C blocks to a computation graph
* @param graph Existing graph builder
* @param blockName Name for this inception block
* @param scale Number of blocks to add
* @param activationScale Scaling factor for activations
* @param input Name of input layer to connect to
* @return Updated GraphBuilder with Inception-ResNet C blocks added
*/
static ComputationGraphConfiguration.GraphBuilder inceptionV1ResC(
ComputationGraphConfiguration.GraphBuilder graph,
String blockName,
int scale,
double activationScale,
String input
);
}
Usage Example:
// Building InceptionResNet architecture
ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("input");
// Add Inception-ResNet A blocks
graph = InceptionResNetHelper.inceptionV1ResA(
graph,
"resnet_a", // block name
3, // number of blocks
0.1, // activation scaling
"input" // input layer
);
Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-zoo