A model zoo for deeplearning4j that provides access to pre-trained neural network models
Collection of convolutional neural network implementations including AlexNet, VGG16/19, ResNet50, GoogLeNet, LeNet, and other popular architectures for image classification and computer vision tasks.
Implementation of AlexNet CNN architecture based on "ImageNet Classification with Deep Convolutional Neural Networks". Suitable for image classification tasks with 224x224x3 input images.
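To make the 224x224x3 input requirement concrete, the sketch below applies the standard convolution output-size formula, floor((input + 2*padding - kernel) / stride) + 1. The kernel, stride, and padding values in `main` are illustrative assumptions for a first AlexNet-style layer; they are not read from the zoo's `AlexNet` class, and `ConvOutputSize` is a hypothetical helper, not part of deeplearning4j.

```java
/** Hypothetical helper (not part of deeplearning4j) for convolution output sizes. */
final class ConvOutputSize {

    /** Computes floor((input + 2 * padding - kernel) / stride) + 1 for one spatial dimension. */
    static int outputSize(int input, int kernel, int stride, int padding) {
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive");
        }
        int span = input + 2 * padding - kernel;
        if (span < 0) {
            throw new IllegalArgumentException("kernel is larger than the padded input");
        }
        // Integer division floors here because span is non-negative.
        return span / stride + 1;
    }

    public static void main(String[] args) {
        // Assumed first-layer settings: 11x11 kernel, stride 4, padding 2.
        int out = outputSize(224, 11, 4, 2);
        System.out.println("224x224 input -> " + out + "x" + out + " feature map");
    }
}
```

The same helper can be reused to trace feature-map sizes through later layers once their actual kernel, stride, and padding are known from `conf()`.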
/**
* AlexNet CNN architecture implementation
*/
class AlexNet extends ZooModel {
/**
* Creates AlexNet with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
AlexNet(int numLabels, long seed, int iterations);
/**
* Creates AlexNet with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
AlexNet(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return MultiLayerConfiguration for the AlexNet architecture
*/
MultiLayerConfiguration conf();
}

Usage Example:
// Create AlexNet for 1000-class ImageNet classification
AlexNet alexNet = new AlexNet(1000, 42, 1);
MultiLayerNetwork model = (MultiLayerNetwork) alexNet.init();
// Custom configuration with workspace mode
AlexNet customAlexNet = new AlexNet(10, 123, 1, WorkspaceMode.SINGLE);
ModelMetaData metadata = customAlexNet.metaData();
// Input shape: [3, 224, 224] (RGB, 224x224 images)

Implementation of VGG16 CNN architecture with support for pre-trained weights from ImageNet, CIFAR-10, and VGGFace datasets.
/**
* VGG-16 CNN architecture implementation with pre-trained weights support
*/
class VGG16 extends ZooModel {
/**
* Creates VGG16 with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
VGG16(int numLabels, long seed, int iterations);
/**
* Creates VGG16 with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
VGG16(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return MultiLayerConfiguration for the VGG16 architecture
*/
MultiLayerConfiguration conf();
}

Pre-trained Weights Support:
PretrainedType.IMAGENET - ImageNet classification weights
PretrainedType.CIFAR10 - CIFAR-10 classification weights
PretrainedType.VGGFACE - VGGFace recognition weights

Usage Example:
// Create VGG16 and load ImageNet pre-trained weights
VGG16 vgg16 = new VGG16(1000, 42, 1);
Model pretrainedModel = vgg16.initPretrained(PretrainedType.IMAGENET);
// Check available pre-trained weights
boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET); // true
boolean hasCIFAR10 = vgg16.pretrainedAvailable(PretrainedType.CIFAR10); // true
boolean hasVGGFace = vgg16.pretrainedAvailable(PretrainedType.VGGFACE); // true
// Create for custom number of classes
VGG16 customVGG16 = new VGG16(10, 42, 1);
MultiLayerNetwork customModel = (MultiLayerNetwork) customVGG16.init();

Implementation of the VGG19 CNN architecture, a deeper variant of VGG16 with 19 weight layers.
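The "16" and "19" in the names count layers with trainable weights. As a rough illustration, the sketch below sums the convolutional layers per block and the three fully connected layers, using the block layouts from the original VGG paper (configurations D and E). The class `VggDepth` is hypothetical and the counts are not read from the zoo's `VGG16` or `VGG19` configurations.

```java
/** Hypothetical helper (not part of deeplearning4j) that counts VGG weight layers. */
final class VggDepth {

    // Convolutional layers in each of the five blocks, per the VGG paper.
    static final int[] VGG16_CONVS_PER_BLOCK = {2, 2, 3, 3, 3};
    static final int[] VGG19_CONVS_PER_BLOCK = {2, 2, 4, 4, 4};

    // Both variants end with three fully connected layers.
    static final int FULLY_CONNECTED_LAYERS = 3;

    /** Returns the number of weight layers: all convolutions plus the fully connected layers. */
    static int weightLayers(int[] convsPerBlock) {
        int total = FULLY_CONNECTED_LAYERS;
        for (int convs : convsPerBlock) {
            total += convs;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println("VGG16 weight layers: " + weightLayers(VGG16_CONVS_PER_BLOCK));
        System.out.println("VGG19 weight layers: " + weightLayers(VGG19_CONVS_PER_BLOCK));
    }
}
```

Pooling layers carry no weights, which is why they are left out of the count.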
/**
* VGG-19 CNN architecture implementation
*/
class VGG19 extends ZooModel {
/**
* Creates VGG19 with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
VGG19(int numLabels, long seed, int iterations);
/**
* Creates VGG19 with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
VGG19(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return MultiLayerConfiguration for the VGG19 architecture
*/
MultiLayerConfiguration conf();
}

Implementation of ResNet50 CNN architecture with residual connections for training very deep networks.
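A residual connection adds a block's input back onto its output, so the block only has to learn a correction to the identity mapping. The sketch below shows that idea on plain arrays; `ResidualBlock` and the lambda standing in for the learned layers are hypothetical, and the activation that ResNet applies after the addition is omitted for brevity. It does not reflect how the zoo's `ResNet50` builds its graph.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

/** Hypothetical illustration (not part of deeplearning4j) of a residual connection. */
final class ResidualBlock {

    /** Computes y = f(x) + x element-wise, where f stands in for the block's learned layers. */
    static double[] forward(double[] x, UnaryOperator<double[]> f) {
        double[] fx = f.apply(x);
        if (fx.length != x.length) {
            throw new IllegalArgumentException("f(x) must have the same length as x");
        }
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = fx[i] + x[i]; // skip connection: add the input back
        }
        return y;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        // If the learned layers output zeros, the block reduces to the identity mapping.
        double[] y = forward(x, v -> new double[v.length]);
        System.out.println(Arrays.toString(y));
    }
}
```

Because the identity path is always present, gradients can flow through the addition unchanged, which is what makes very deep stacks of such blocks trainable.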
/**
* ResNet50 CNN architecture implementation with residual connections
*/
class ResNet50 extends ZooModel {
/**
* Creates ResNet50 with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
ResNet50(int numLabels, long seed, int iterations);
/**
* Creates ResNet50 with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
ResNet50(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

Implementation of GoogLeNet/Inception CNN architecture with inception modules for efficient computation.
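An inception module runs several branches in parallel over the same input and concatenates their outputs along the channel dimension, so the module's output depth is the sum of the branch depths. The sketch below works through that arithmetic for the branch sizes the GoogLeNet paper lists for its "inception (3a)" module. `InceptionChannels` is a hypothetical helper, and the numbers are taken from the paper rather than from the zoo's `GoogLeNet` configuration.

```java
/** Hypothetical helper (not part of deeplearning4j) for inception-module channel arithmetic. */
final class InceptionChannels {

    /** Returns the channel count after concatenating the given branch outputs depth-wise. */
    static int concatChannels(int... branchChannels) {
        int total = 0;
        for (int channels : branchChannels) {
            if (channels < 0) {
                throw new IllegalArgumentException("channel counts must be non-negative");
            }
            total += channels;
        }
        return total;
    }

    public static void main(String[] args) {
        // Branches of inception (3a) in the paper: 1x1, 3x3, 5x5, and pooling projection.
        int out = concatChannels(64, 128, 32, 32);
        System.out.println("inception (3a) output channels: " + out);
    }
}
```

The efficiency comes from the 1x1 convolutions that reduce depth before the more expensive 3x3 and 5x5 branches; only the final branch widths enter the sum above.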
/**
* GoogLeNet/Inception CNN architecture implementation
*/
class GoogLeNet extends ZooModel {
/**
* Creates GoogLeNet with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
GoogLeNet(int numLabels, long seed, int iterations);
/**
* Creates GoogLeNet with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
GoogLeNet(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return ComputationGraphConfiguration for the GoogLeNet architecture
*/
ComputationGraphConfiguration conf();
}

Implementation of LeNet CNN architecture, one of the earliest successful CNNs for digit recognition.
/**
* LeNet CNN architecture implementation for digit recognition
*/
class LeNet extends ZooModel {
/**
* Creates LeNet with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
LeNet(int numLabels, long seed, int iterations);
/**
* Creates LeNet with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
LeNet(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return MultiLayerConfiguration for the LeNet architecture
*/
MultiLayerConfiguration conf();
}

Implementation of a simple CNN architecture for basic image classification tasks.
/**
* Simple CNN architecture implementation
*/
class SimpleCNN extends ZooModel {
/**
* Creates SimpleCNN with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
SimpleCNN(int numLabels, long seed, int iterations);
/**
* Creates SimpleCNN with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
SimpleCNN(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return MultiLayerConfiguration for the SimpleCNN architecture
*/
MultiLayerConfiguration conf();
}

Implementation of InceptionResNetV1 combining Inception modules with residual connections.
/**
* InceptionResNetV1 architecture combining Inception modules with residual connections
*/
class InceptionResNetV1 extends ZooModel {
/**
* Creates InceptionResNetV1 with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
InceptionResNetV1(int numLabels, long seed, int iterations);
/**
* Creates InceptionResNetV1 with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
InceptionResNetV1(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
}

Implementation of FaceNet NN4 Small2 architecture specifically designed for face recognition tasks.
/**
* FaceNet NN4 Small2 architecture for face recognition
*/
class FaceNetNN4Small2 extends ZooModel {
/**
* Creates FaceNetNN4Small2 with basic configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
*/
FaceNetNN4Small2(int numLabels, long seed, int iterations);
/**
* Creates FaceNetNN4Small2 with workspace mode configuration
* @param numLabels Number of output classes
* @param seed Random seed for reproducibility
* @param iterations Number of training iterations
* @param workspaceMode Memory workspace configuration
*/
FaceNetNN4Small2(int numLabels, long seed, int iterations, WorkspaceMode workspaceMode);
/**
* Returns the network configuration
* @return ComputationGraphConfiguration for the FaceNetNN4Small2 architecture
*/
ComputationGraphConfiguration conf();
}

Basic Model Creation:
// Create any CNN model with standard configuration
AlexNet alexNet = new AlexNet(1000, 42, 1);
VGG16 vgg16 = new VGG16(1000, 42, 1);
ResNet50 resNet50 = new ResNet50(1000, 42, 1);
// Initialize models
Model alexNetModel = alexNet.init();
Model vgg16Model = vgg16.init();
Model resNet50Model = resNet50.init();

Custom Configuration:
// Custom number of classes for your dataset
int numClasses = 10; // e.g., for CIFAR-10
VGG16 customVGG16 = new VGG16(numClasses, 123, 1, WorkspaceMode.SINGLE);
// Get input requirements
ModelMetaData metadata = customVGG16.metaData();
int[][] inputShape = metadata.getInputShape(); // [[3, 224, 224]]: one shape per network input

Pre-trained Models:
// Load pre-trained weights
VGG16 vgg16 = new VGG16(1000, 42, 1);
Model pretrainedModel = vgg16.initPretrained(PretrainedType.IMAGENET);
// Check what pre-trained weights are available
boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET);

Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-zoo