A model zoo for deeplearning4j that provides access to pre-trained neural network models
—
Foundation interface and abstract class that define how all zoo models are instantiated, configured, and used, including pre-trained model support with automatic downloading and caching.
Core interface that all zoo models must implement, defining the contract for model instantiation and metadata access.
/**
* Interface for defining a model that can be instantiated and return information about itself.
*/
interface InstantiableModel {
/**
* Sets the input shape for the model
* @param inputShape 2D array with one shape (int[]) per model input, e.g. {{3, 224, 224}}
*/
void setInputShape(int[][] inputShape);
/**
* Initializes and returns the model instance
* @return Model instance ready for training or inference
*/
Model init();
/**
* Returns metadata about the model including input shapes and outputs
* @return ModelMetaData containing model information
*/
ModelMetaData metaData();
/**
* Returns the zoo type classification for this model
* @return ZooType enum value
*/
ZooType zooType();
/**
* Returns the DeepLearning4j model class type
* @return Class extending Model (MultiLayerNetwork or ComputationGraph)
*/
Class<? extends Model> modelType();
/**
* Returns URL for downloading pre-trained weights
* @param pretrainedType Type of pre-trained weights to download
* @return URL string or null if not available
*/
String pretrainedUrl(PretrainedType pretrainedType);
/**
* Returns checksum for verifying downloaded pre-trained weights
* @param pretrainedType Type of pre-trained weights
* @return Checksum value for verification
*/
long pretrainedChecksum(PretrainedType pretrainedType);
}

Usage Example:
AlexNet model = new AlexNet(1000, 42, 1);
ModelMetaData metadata = model.metaData();
int[][] inputShape = metadata.getInputShape(); // [[3, 224, 224]]
ZooType type = model.zooType(); // ZooType.ALEXNET
Class<? extends Model> modelClass = model.modelType(); // MultiLayerNetwork.class

Base implementation providing common functionality for downloading, caching, and initializing pre-trained models.
/**
* A zoo model is instantiable, returns information about itself, and can download
* pretrained models if available.
*/
abstract class ZooModel<T> implements InstantiableModel {
/**
* Root directory for caching downloaded models (defaults to ~/.deeplearning4j/)
*/
static File ROOT_CACHE_DIR;
/**
* Checks if pre-trained weights are available for the specified type
* @param pretrainedType Type of pre-trained weights to check
* @return true if pre-trained weights are available
*/
boolean pretrainedAvailable(PretrainedType pretrainedType);
/**
* Initializes model with ImageNet pre-trained weights (default)
* @return Model instance with pre-trained weights loaded
* @throws IOException if download or loading fails
*/
Model initPretrained() throws IOException;
/**
* Initializes model with specified pre-trained weights
* @param pretrainedType Type of pre-trained weights to load
* @return Model instance with pre-trained weights loaded
* @throws IOException if download or loading fails
* @throws UnsupportedOperationException if pre-trained weights not available
*/
Model initPretrained(PretrainedType pretrainedType) throws IOException;
}

Usage Examples:
// Check if pre-trained weights are available
VGG16 vgg16 = new VGG16(1000, 42, 1);
boolean hasImageNet = vgg16.pretrainedAvailable(PretrainedType.IMAGENET); // true
boolean hasMNIST = vgg16.pretrainedAvailable(PretrainedType.MNIST); // false
// Load pre-trained model (downloads and caches automatically)
Model pretrainedModel = vgg16.initPretrained(PretrainedType.IMAGENET);
// Load default pre-trained weights (ImageNet)
Model defaultPretrained = vgg16.initPretrained();
// Custom cache directory
ZooModel.ROOT_CACHE_DIR = new File("/custom/cache/path");

Contains metadata describing a model: its input shapes, number of outputs, and zoo type classification.
/**
* Metadata describing a model, including input shapes. This is helpful for instantiating
* the model programmatically and ensuring appropriate inputs are used.
*/
class ModelMetaData {
/**
* Creates model metadata
* @param inputShape Array of input shape dimensions
* @param numOutputs Number of model outputs (not the number of classes/labels; e.g. a 10-label AlexNet reports 1)
* @param zooType Zoo type classification
*/
ModelMetaData(int[][] inputShape, int numOutputs, ZooType zooType);
/**
* Gets the input shape dimensions
* @return 2D array with one shape (int[]) per model input
*/
int[][] getInputShape();
/**
* Gets the number of model outputs (this is not the number of classes/labels)
* @return Number of outputs
*/
int getNumOutputs();
/**
* Gets the zoo type classification
* @return ZooType enum value
*/
ZooType getZooType();
/**
* Indicates whether the model should be fed a MultiDataSet, i.e. whether it expects multiple inputs
* @return true if multiple inputs are expected
*/
boolean useMDS();
}

Usage Example:
AlexNet alexNet = new AlexNet(10, 42, 1);
ModelMetaData metadata = alexNet.metaData();
int[][] inputShape = metadata.getInputShape(); // [[3, 224, 224]]
int numOutputs = metadata.getNumOutputs(); // 1 (a single output, even though the model has 10 labels)
ZooType zooType = metadata.getZooType(); // ZooType.CNN (general category; alexNet.zooType() returns ZooType.ALEXNET)
boolean multipleInputs = metadata.useMDS(); // false

Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-zoo