A model zoo for deeplearning4j that provides access to pre-trained neural network models

npx @tessl/cli install tessl/maven-org-deeplearning4j--deeplearning4j-zoo@0.9.0

A model zoo for deeplearning4j that provides access to pre-trained neural network models. The library offers popular CNN architectures such as AlexNet, VGG, ResNet, and Inception, as well as RNN models for text generation, and supports downloading and using pre-trained weights from various datasets.
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-zoo</artifactId>
    <version>0.9.1</version>
</dependency>

import org.deeplearning4j.zoo.*;
import org.deeplearning4j.zoo.model.*;
import org.deeplearning4j.zoo.model.helper.*;
import org.deeplearning4j.zoo.util.imagenet.*;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.activations.Activation;

For specific models:

import org.deeplearning4j.zoo.model.AlexNet;
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.model.ResNet50;
import org.deeplearning4j.zoo.model.helper.FaceNetHelper;
import org.deeplearning4j.zoo.model.helper.InceptionResNetHelper;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;

Basic usage:

import org.deeplearning4j.zoo.*;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import org.deeplearning4j.zoo.model.*;
import org.deeplearning4j.nn.api.Model;

// Instantiate a specific model (arguments: numLabels = 1000, seed = 42, iterations = 1)
AlexNet alexNet = new AlexNet(1000, 42, 1);
Model model = alexNet.init();

// Load pre-trained weights. ZooModel takes care of downloading, caching and checksum
// verification of the weight file.
// NOTE(review): in DL4J 0.9.x initPretrained(...) is declared `throws IOException`
// (download/read failure), so it must be handled here or propagated by the enclosing method.
VGG16 vgg16 = new VGG16(1000, 42, 1);
Model pretrainedModel;
try {
    pretrainedModel = vgg16.initPretrained(PretrainedType.IMAGENET);
} catch (IOException e) {
    throw new UncheckedIOException("Failed to load pretrained VGG16 (IMAGENET) weights", e);
}

// Use the model selector for multiple models: all models of type CNN, each configured
// with 10 output labels. Note that init() is called on every selected model.
Map<ZooType, ZooModel> models = ModelSelector.select(ZooType.CNN, 10);
for (ZooModel zooModel : models.values()) {
    Model initializedModel = zooModel.init();
}

// Get model metadata
ModelMetaData metadata = alexNet.metaData();
int[][] inputShape = metadata.getInputShape();
int numOutputs = metadata.getNumOutputs();

The deeplearning4j-zoo is built around several key components:

- InstantiableModel defines the contract for all zoo models, ensuring consistent initialization and metadata access.
- The ZooModel abstract class provides common functionality, including pre-trained model downloading, caching, and checksum verification.
- ModelSelector is a utility for programmatically selecting and instantiating multiple models.
- Enumerations (ZooType, PretrainedType) categorize models and pre-trained weight types.

Foundation interface and abstract class defining how all zoo models are instantiated, configured, and used. Includes pre-trained model support with automatic downloading and caching.
/**
 * Contract implemented by every zoo model: configure the input shape, build the network,
 * and expose metadata plus the location and checksum of any pre-trained weights.
 */
interface InstantiableModel {
    /**
     * Overrides the input shape(s) the network is built for.
     * NOTE(review): presumably one int[] per network input — confirm.
     */
    void setInputShape(int[][] inputShape);

    /** Builds the network. Pre-trained weights are loaded via ZooModel.initPretrained(...) instead. */
    Model init();

    /** Returns the model's input shape(s), number of outputs and ZooType. */
    ModelMetaData metaData();

    /** Returns the ZooType identifying this model. */
    ZooType zooType();

    /**
     * Returns the concrete network class produced by init().
     * NOTE(review): presumably MultiLayerNetwork or ComputationGraph — confirm per model.
     */
    Class<? extends Model> modelType();

    /**
     * Returns the URL of the pre-trained weights of the given type.
     * NOTE(review): presumably null when no such weights exist — confirm.
     */
    String pretrainedUrl(PretrainedType pretrainedType);

    /** Returns the expected checksum of the pre-trained weight file, used to verify the download. */
    long pretrainedChecksum(PretrainedType pretrainedType);
}
/**
 * Base class for zoo models. On top of {@link InstantiableModel} it adds pre-trained weight
 * support: downloading, local caching and checksum verification of weight files.
 */
abstract class ZooModel<T> implements InstantiableModel {
    /** Root directory of the local cache for downloaded pre-trained weights. */
    static File ROOT_CACHE_DIR;

    /**
     * Returns whether pre-trained weights of the given type are available for this model.
     * NOTE(review): presumably true exactly when pretrainedUrl(pretrainedType) is non-null — confirm.
     */
    boolean pretrainedAvailable(PretrainedType pretrainedType);

    /**
     * Initializes the model with its default pre-trained weights.
     * NOTE(review): the default is IMAGENET in DL4J 0.9.x — confirm for the version in use.
     *
     * @return the initialized network with pre-trained parameters
     * @throws java.io.IOException if the weight file cannot be downloaded or read
     */
    Model initPretrained() throws java.io.IOException;

    /**
     * Initializes the model with pre-trained weights of the given type, downloading and
     * caching them if necessary. Check {@link #pretrainedAvailable(PretrainedType)} first.
     * NOTE(review): DL4J throws an unchecked exception when the requested weights are unavailable — confirm.
     *
     * @param pretrainedType dataset the weights were trained on
     * @return the initialized network with pre-trained parameters
     * @throws java.io.IOException if the weight file cannot be downloaded or read
     */
    Model initPretrained(PretrainedType pretrainedType) throws java.io.IOException;
}

Collection of convolutional neural network implementations including AlexNet, VGG16/19, ResNet50, GoogLeNet, LeNet, and other popular architectures for image classification and computer vision tasks.
/** AlexNet convolutional network. */
class AlexNet extends ZooModel {
    /**
     * @param numLabels  number of output classes
     * @param seed       random seed (presumably for parameter initialization — confirm)
     * @param iterations iterations setting for the network configuration — confirm exact semantics
     */
    AlexNet(int numLabels, long seed, int iterations);

    /** Same as the three-argument constructor, with an explicit workspace mode. */
    AlexNet(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}
/** VGG16 convolutional network; usable with IMAGENET pre-trained weights (see usage example). */
class VGG16 extends ZooModel {
    /**
     * @param numLabels  number of output classes
     * @param seed       random seed (presumably for parameter initialization — confirm)
     * @param iterations iterations setting for the network configuration — confirm exact semantics
     */
    VGG16(int numLabels, long seed, int iterations);

    /** Same as the three-argument constructor, with an explicit workspace mode. */
    VGG16(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}
/** ResNet-50 convolutional network. */
class ResNet50 extends ZooModel {
    /**
     * @param numLabels  number of output classes
     * @param seed       random seed (presumably for parameter initialization — confirm)
     * @param iterations iterations setting for the network configuration — confirm exact semantics
     */
    ResNet50(int numLabels, long seed, int iterations);

    /** Same as the three-argument constructor, with an explicit workspace mode. */
    ResNet50(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

Recurrent neural network implementations including LSTM models for text generation and sequence modeling tasks.
/** LSTM recurrent network for text generation. */
class TextGenerationLSTM extends ZooModel {
    /**
     * @param numLabels  number of output classes
     *                   (NOTE(review): presumably the size of the character/token vocabulary — confirm)
     * @param seed       random seed (presumably for parameter initialization — confirm)
     * @param iterations iterations setting for the network configuration — confirm exact semantics
     */
    TextGenerationLSTM(int numLabels, long seed, int iterations);

    /** Same as the three-argument constructor, with an explicit workspace mode. */
    TextGenerationLSTM(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

Tools for programmatically selecting, instantiating, and working with multiple zoo models, including helper classes for building custom architectures.
/**
 * Selects several zoo models at once. Every overload returns the selected models keyed by
 * their ZooType; call init() (or initPretrained(...)) on each returned model to build it.
 *
 * NOTE(review): the defaults used by the shorter overloads (numLabels, seed, iterations,
 * workspace mode) are not shown here — confirm against the implementation.
 * NOTE(review): seed is an int here but a long in the individual model constructors.
 */
class ModelSelector {
    /** Selects the model(s) for zooType; presumably a group such as CNN expands to all its members. */
    static Map<ZooType, ZooModel> select(ZooType zooType);

    /** As above, configuring every selected model with numLabels outputs. */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels);

    /** As above, with an explicit workspace mode. */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, WorkspaceMode workspaceMode);

    /** As above, with explicit seed and iterations. */
    static Map<ZooType, ZooModel> select(ZooType zooType, int numLabels, int seed, int iterations, WorkspaceMode workspaceMode);

    /** Selects the models for each of zooTypes, using the given workspace mode. */
    static Map<ZooType, ZooModel> select(WorkspaceMode workspaceMode, ZooType... zooTypes);

    /** Selects the models for each of zooTypes. */
    static Map<ZooType, ZooModel> select(ZooType... zooTypes);

    /** Selects the models for each of zooTypes with fully explicit configuration. */
    static Map<ZooType, ZooModel> select(int numLabels, int seed, int iterations, WorkspaceMode workspaceMode, ZooType... zooTypes);
}

Utilities for working with ImageNet classifications, including label decoding and prediction interpretation for models trained on ImageNet dataset.
/** Maps ImageNet class indices to human-readable labels. */
class ImageNetLabels {
    /** Creates the label lookup. NOTE(review): may load or download the class-index data — confirm. */
    ImageNetLabels();

    /** Returns the label of class index n. */
    String getLabel(int n);

    /**
     * Renders the predictions as readable text.
     * NOTE(review): presumably expects the output of an ImageNet-trained model and lists the
     * top-ranked classes — confirm.
     */
    String decodePredictions(INDArray predictions);
}

/** Describes a zoo model's inputs and outputs, as returned by InstantiableModel.metaData(). */
class ModelMetaData {
    ModelMetaData(int[][] inputShape, int numOutputs, ZooType zooType);

    /** Input shape(s) of the model (presumably one int[] per input — confirm). */
    int[][] getInputShape();

    /** Number of model outputs. */
    int getNumOutputs();

    /** ZooType of the described model. */
    ZooType getZooType();

    /**
     * NOTE(review): presumably true when the model has more than one input, meaning a
     * MultiDataSet should be used for training/inference — confirm.
     */
    boolean useMDS();
}
/**
 * Identifiers for zoo models and model groups. CNN is passed to ModelSelector in the usage
 * example to select a whole group; the architecture names identify individual models.
 * Declaration order is significant (it fixes ordinal() and values()) and must not change.
 */
enum ZooType {
    // Group selectors (ALL presumably covers every model — confirm)
    ALL,
    CNN,

    // Convolutional architectures
    SIMPLECNN,
    ALEXNET,
    LENET,
    GOOGLENET,
    VGG16,
    VGG19,
    RESNET50,
    INCEPTIONRESNETV1,
    FACENETNN4SMALL2,

    // Recurrent group selector and architecture
    RNN,
    TEXTGENLSTM
}
/**
 * Datasets a zoo model's pre-trained weights may come from. Availability differs per model;
 * check ZooModel.pretrainedAvailable(type) before calling initPretrained(type).
 */
enum PretrainedType {
    IMAGENET,
    MNIST,
    CIFAR10,
    VGGFACE
}
/**
 * Workspace modes accepted by the zoo model constructors and ModelSelector overloads.
 * NOTE(review): consult the DL4J workspace documentation for the exact behavior of each mode.
 */
enum WorkspaceMode {
    NONE,
    SINGLE,
    SEPARATE
}
/** Network configuration for MultiLayerNetwork models (a sequential stack of layers). */
class MultiLayerConfiguration {
    // Members omitted from this summary.
}
/** Network configuration for ComputationGraph models. */
class ComputationGraphConfiguration {
    /**
     * Builder for computation-graph configurations (a static nested class in DL4J).
     * Typical sequence: addInputs(...), then addLayer/addVertex calls that name their
     * inputs, then setOutputs(...), then build().
     */
    static class GraphBuilder {
        /** Declares the names of the graph's inputs. */
        GraphBuilder addInputs(String... inputs);

        /** Adds a layer named layerName that is fed by the named inputs. */
        GraphBuilder addLayer(String layerName, Layer layer, String... inputs);

        /** Adds a non-layer vertex (e.g. MergeVertex, ElementWiseVertex) fed by the named inputs. */
        GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... inputs);

        /**
         * Declares, in order, which layers/vertices are the graph's outputs.
         * Required: build() fails validation when no outputs have been set.
         */
        GraphBuilder setOutputs(String... outputNames);

        /** Validates and builds the configuration. */
        ComputationGraphConfiguration build();
    }
}
/** Base class for all neural network layer configurations. */
abstract class Layer {
    // Members omitted from this summary.
}
/** Convolutional layer configuration. */
class ConvolutionLayer extends Layer {
    /**
     * Algorithm-selection mode for the convolution implementation.
     * NOTE(review): presumably controls cuDNN algorithm choice — confirm against DL4J docs.
     */
    enum AlgoMode {
        NO_WORKSPACE, PREFER_FASTEST, USER_SPECIFIED
    }
}
/** Subsampling (pooling) layer configuration. */
class SubsamplingLayer extends Layer {
    /** Pooling operation applied by the layer. */
    enum PoolingType {
        MAX, AVG, SUM, PNORM
    }
}
/** Fully connected layer configuration. */
class DenseLayer extends Layer {
    // Members omitted from this summary.
}
/** Batch normalization layer configuration. */
class BatchNormalization extends Layer {
    // Members omitted from this summary.
}
/** Layer configuration that only applies an activation function. */
class ActivationLayer extends Layer {
    // Members omitted from this summary.
}
/** Base class for non-layer vertices added to a computation graph via GraphBuilder.addVertex. */
abstract class GraphVertex {
    // Members omitted from this summary.
}
/** Vertex that merges the outputs of multiple inputs into one. */
class MergeVertex extends GraphVertex {
    // Members omitted from this summary.
}
/** Vertex that combines its inputs with an element-wise operation. */
class ElementWiseVertex extends GraphVertex {
    // Members omitted from this summary.
}
/**
 * Common type of all DeepLearning4j networks ({@code org.deeplearning4j.nn.api.Model}).
 * It is an interface, not a base class, and it does not declare {@code output(...)} or
 * {@code fit(DataSetIterator)}: those live on the concrete network classes below. Zoo
 * {@code init()} / {@code initPretrained(...)} return a {@code Model}, so cast the result to the
 * class reported by {@code modelType()} before running inference or training.
 */
interface Model {
    // Members omitted from this summary.
}

/** Network built from a MultiLayerConfiguration (a sequential stack of layers). */
class MultiLayerNetwork implements Model {
    /** Runs inference on the given input. */
    INDArray output(INDArray input);

    /** Trains on the data provided by the iterator. */
    void fit(DataSetIterator iterator);
}

/** Network built from a ComputationGraphConfiguration; may have multiple inputs and outputs. */
class ComputationGraph implements Model {
    /** Runs inference; takes one array per graph input and returns one array per graph output. */
    INDArray[] output(INDArray... input);

    /** Trains on the data provided by the iterator. */
    void fit(DataSetIterator iterator);
}
/**
 * Activation functions (org.nd4j.linalg.activations.Activation, imported in the setup snippet).
 * NOTE(review): this summary lists only some constants; the ND4J enum defines more — confirm
 * against the ND4J version in use.
 */
enum Activation {
    RELU,
    TANH,
    SIGMOID,
    SOFTMAX,
    IDENTITY,
    LEAKYRELU
}