Keras model import functionality for DeepLearning4J
---
Comprehensive mapping of Keras layer types to DeepLearning4J layers. The library supports most common neural network layer types and provides automatic translation of parameters and configurations.
Maps Keras Dense layers to DeepLearning4J DenseLayer.
// Keras Dense layer configuration supported:
// - units (output_dim): number of output units
// - activation: activation function
// - use_bias: whether to use bias
// - kernel_initializer (init): weight initialization
// - bias_initializer: bias initialization
// - kernel_regularizer (W_regularizer): weight regularization
// - bias_regularizer (b_regularizer): bias regularization
// - activity_regularizer: output regularization
// - kernel_constraint: weight constraints
// - bias_constraint: bias constraints
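As an illustrative sketch (not the importer's internal code), a hypothetical Keras `Dense(units=64, activation='relu', kernel_initializer='glorot_uniform', use_bias=True)` corresponds roughly to the following DL4J builder; mapping `glorot_uniform` to `WeightInit.XAVIER` is the conventional equivalence:

```java
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;

public class DenseMappingSketch {
    public static void main(String[] args) {
        // Rough DL4J equivalent of Keras Dense(units=64, activation='relu',
        // kernel_initializer='glorot_uniform', use_bias=True)
        DenseLayer mapped = new DenseLayer.Builder()
                .nOut(64)                      // units
                .activation(Activation.RELU)   // activation
                .weightInit(WeightInit.XAVIER) // glorot_uniform
                .hasBias(true)                 // use_bias
                .build();
        System.out.println(mapped.getNOut()); // prints 64
    }
}
```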
Maps Keras Convolution1D and Convolution2D layers to DeepLearning4J ConvolutionLayer.
// Keras Convolution layer configuration supported:
// - filters (nb_filter): number of convolution filters
// - kernel_size (nb_row, nb_col): convolution kernel size
// - strides (subsample): stride values
// - padding (border_mode): padding type ('same', 'valid')
// - activation: activation function
// - use_bias: whether to use bias
// - kernel_initializer: weight initialization
// - bias_initializer: bias initialization
// - kernel_regularizer: weight regularization
// - bias_regularizer: bias regularization
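Similarly, a hypothetical Keras `Conv2D(filters=32, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')` maps roughly to this sketch using DL4J's public builder API (`ConvolutionMode.Same` plays the role of `padding='same'`):

```java
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.nd4j.linalg.activations.Activation;

public class ConvMappingSketch {
    public static void main(String[] args) {
        ConvolutionLayer conv = new ConvolutionLayer.Builder(3, 3) // kernel_size
                .nOut(32)                              // filters
                .stride(1, 1)                          // strides
                .convolutionMode(ConvolutionMode.Same) // padding='same'
                .activation(Activation.RELU)           // activation
                .build();
        System.out.println(conv.getNOut()); // prints 32
    }
}
```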
Maps Keras pooling layers to DeepLearning4J pooling layers.
// Supported pooling types:
// - MaxPooling1D -> SubsamplingLayer with PoolingType.MAX
// - MaxPooling2D -> SubsamplingLayer with PoolingType.MAX
// - AveragePooling1D -> SubsamplingLayer with PoolingType.AVG
// - AveragePooling2D -> SubsamplingLayer with PoolingType.AVG
// - GlobalMaxPooling1D -> GlobalPoolingLayer with PoolingType.MAX
// - GlobalMaxPooling2D -> GlobalPoolingLayer with PoolingType.MAX
// - GlobalAveragePooling1D -> GlobalPoolingLayer with PoolingType.AVG
// - GlobalAveragePooling2D -> GlobalPoolingLayer with PoolingType.AVG
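For example, a Keras `MaxPooling2D(pool_size=(2,2), strides=(2,2))` and a `GlobalAveragePooling2D()` correspond roughly to the following sketch (the pool sizes are hypothetical):

```java
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;

public class PoolingMappingSketch {
    public static void main(String[] args) {
        // MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
        SubsamplingLayer maxPool = new SubsamplingLayer.Builder(PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build();

        // GlobalAveragePooling2D()
        GlobalPoolingLayer globalAvg = new GlobalPoolingLayer.Builder(PoolingType.AVG)
                .build();

        System.out.println(maxPool.getPoolingType());
        System.out.println(globalAvg.getPoolingType());
    }
}
```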
Maps Keras LSTM layers to DeepLearning4J LSTM layers.
// Keras LSTM layer configuration supported:
// - units: number of LSTM units
// - activation: activation function for gates
// - recurrent_activation: recurrent activation function
// - use_bias: whether to use bias
// - kernel_initializer: input weight initialization
// - recurrent_initializer: recurrent weight initialization
// - bias_initializer: bias initialization
// - dropout: input dropout rate
// - recurrent_dropout: recurrent dropout rate
// - return_sequences: whether to return full sequence
// - return_state: whether to return cell state
// - go_backwards: process sequence backwards
// - stateful: maintain state between batches
// - unroll: unroll the recurrent computation
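A hypothetical Keras `LSTM(units=128, activation='tanh', recurrent_activation='sigmoid')` maps roughly to the sketch below; note that `return_sequences` is reflected in how the following layer consumes the output rather than in a builder flag:

```java
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.nd4j.linalg.activations.Activation;

public class LstmMappingSketch {
    public static void main(String[] args) {
        LSTM lstm = new LSTM.Builder()
                .nOut(128)                                  // units
                .activation(Activation.TANH)                // activation
                .gateActivationFunction(Activation.SIGMOID) // recurrent_activation
                .build();
        System.out.println(lstm.getNOut()); // prints 128
    }
}
```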
Maps Keras Activation layers to DeepLearning4J ActivationLayer.
// Supported activation functions:
// - relu -> ReLU
// - sigmoid -> Sigmoid
// - tanh -> Tanh
// - softmax -> Softmax
// - linear -> Identity
// - softplus -> Softplus
// - softsign -> Softsign
// - hard_sigmoid -> HardSigmoid
// - elu -> ELU
// - selu -> SELU
// - swish -> Swish

Maps Keras Dropout layers to DeepLearning4J DropoutLayer.
// Keras Dropout configuration:
// - rate: dropout probability (0.0 to 1.0)
// - noise_shape: shape for dropout mask
// - seed: random seed for reproducibility

Maps Keras Flatten layers to DeepLearning4J preprocessors.
// Flattens multi-dimensional input to 1D
// Automatically handles different input shapes
// Maps to appropriate DL4J InputPreProcessor

Maps Keras Embedding layers to DeepLearning4J EmbeddingLayer.
// Keras Embedding configuration:
// - input_dim: vocabulary size
// - output_dim: embedding dimension
// - embeddings_initializer: weight initialization
// - embeddings_regularizer: weight regularization
// - embeddings_constraint: weight constraints
// - mask_zero: mask zero values
// - input_length: input sequence length

Maps Keras BatchNormalization layers to DeepLearning4J BatchNormalization.
// Keras BatchNormalization configuration:
// - axis: normalization axis
// - momentum: momentum for moving averages
// - epsilon: small constant for numerical stability
// - center: whether to use beta parameter
// - scale: whether to use gamma parameter
// - beta_initializer: beta initialization
// - gamma_initializer: gamma initialization
// - moving_mean_initializer: moving mean initialization
// - moving_variance_initializer: moving variance initialization
// - beta_regularizer: beta regularization
// - gamma_regularizer: gamma regularization
// - beta_constraint: beta constraints
// - gamma_constraint: gamma constraints

Custom implementation for Local Response Normalization (LRN).
// KerasLRN class provides:
// - alpha: normalization parameter
// - beta: normalization parameter
// - depth_radius: normalization radius
// - bias: bias parameter

Maps Keras ZeroPadding1D and ZeroPadding2D layers to appropriate preprocessors.
// Zero padding configuration:
// - padding: padding values for each dimension
// - Supports symmetric and asymmetric padding

Maps Keras Merge layers to DeepLearning4J merge vertices.
// Supported merge modes:
// - add -> ElementWiseVertex with Add operation
// - multiply -> ElementWiseVertex with Product operation
// - average -> ElementWiseVertex with Average operation
// - maximum -> ElementWiseVertex with Max operation
// - concatenate -> MergeVertex
// - dot -> DotProductVertex

The library automatically detects Keras layer types and maps them to appropriate DeepLearning4J layers:
// KerasLayer factory method
public static KerasLayer getKerasLayerFromConfig(
        Map<String, Object> layerConfig,
        boolean enforceTrainingConfig
) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException;

Each layer type has a specific configuration mapping.
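A usage sketch of the factory above: the nested `class_name`/`config` map mirrors the layout of Keras model JSON, but the concrete values here are hypothetical, and newer DL4J versions may require additional keys (such as the Keras major version):

```java
import java.util.HashMap;
import java.util.Map;

import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

public class LayerFactorySketch {
    public static void main(String[] args) throws Exception {
        // Inner "config" dict of a Keras Dense layer (hypothetical values)
        Map<String, Object> config = new HashMap<>();
        config.put("units", 64);
        config.put("activation", "relu");

        // Outer wrapper: {"class_name": "Dense", "config": {...}}
        Map<String, Object> layerConfig = new HashMap<>();
        layerConfig.put("class_name", "Dense");
        layerConfig.put("config", config);

        KerasLayer layer = KerasLayer.getKerasLayerFromConfig(layerConfig, false);
        System.out.println(layer.getClassName());
    }
}
```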
Weights are automatically transferred with proper shape handling:
// Weight copying process:
// 1. Extract weights from HDF5 format
// 2. Handle parameter naming conventions (TensorFlow vs Theano backends)
// 3. Reshape weights to match DL4J expectations
// 4. Apply to corresponding DL4J layers

| Keras Layer | DL4J Mapping | Support Level |
|---|---|---|
| Dense | DenseLayer | Full |
| Convolution1D | ConvolutionLayer | Full |
| Convolution2D | ConvolutionLayer | Full |
| MaxPooling1D | SubsamplingLayer | Full |
| MaxPooling2D | SubsamplingLayer | Full |
| AveragePooling1D | SubsamplingLayer | Full |
| AveragePooling2D | SubsamplingLayer | Full |
| GlobalMaxPooling1D | GlobalPoolingLayer | Full |
| GlobalMaxPooling2D | GlobalPoolingLayer | Full |
| GlobalAveragePooling1D | GlobalPoolingLayer | Full |
| GlobalAveragePooling2D | GlobalPoolingLayer | Full |
| LSTM | LSTM | Full |
| Dropout | DropoutLayer | Full |
| Activation | ActivationLayer | Full |
| Flatten | Preprocessor | Full |
| Embedding | EmbeddingLayer | Full |
| BatchNormalization | BatchNormalization | Full |
| Merge | MergeVertex/ElementWiseVertex | Full |
| ZeroPadding1D | Preprocessor | Full |
| ZeroPadding2D | Preprocessor | Full |
| Input | InputType | Full |
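The reshaping in step 3 of the weight-copying process above is where backend layout differences matter most. As a plain-Java illustration (not the importer's code), a TensorFlow-format Conv2D kernel stored as [height][width][in][out] can be permuted to DL4J's conventional [out][in][height][width] layout:

```java
public class KernelPermuteSketch {
    // Permute a [h][w][in][out] kernel (TensorFlow layout) into
    // [out][in][h][w] (DL4J layout). Illustration only.
    static double[][][][] toDl4jLayout(double[][][][] tf) {
        int h = tf.length, w = tf[0].length;
        int in = tf[0][0].length, out = tf[0][0][0].length;
        double[][][][] dl4j = new double[out][in][h][w];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                for (int c = 0; c < in; c++)
                    for (int k = 0; k < out; k++)
                        dl4j[k][c][i][j] = tf[i][j][c][k];
        return dl4j;
    }

    public static void main(String[] args) {
        double[][][][] tf = new double[3][3][2][4]; // 3x3 kernel, 2 in, 4 out channels
        tf[1][2][0][3] = 5.0;                       // one marker weight
        double[][][][] dl4j = toDl4jLayout(tf);
        System.out.println(dl4j[3][0][1][2]);       // prints 5.0
    }
}
```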
A strict import can be wrapped in a try/catch and retried with relaxed enforcement:

try {
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("model.h5", true);
} catch (UnsupportedKerasConfigurationException e) {
    System.out.println("Unsupported feature: " + e.getMessage());
    // Retry with training-configuration enforcement relaxed
    // (the enclosing method must also declare IOException and
    // InvalidKerasConfigurationException, which this call can throw)
    ComputationGraph model = KerasModelImport.importKerasModelAndWeights("model.h5", false);
    System.out.println("Model imported with warnings");
}

For unsupported layer types, you can extend the library:
// Example custom layer implementation
public class MyCustomKerasLayer extends KerasLayer {

    public MyCustomKerasLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        // Custom implementation
    }

    @Override
    public Layer getLayer() throws UnsupportedKerasConfigurationException {
        // Return the appropriate DL4J layer
        return this.layer;
    }

    @Override
    public InputType getOutputType(InputType... inputTypes)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // Return the output type; passing the input type through is a placeholder
        return inputTypes[0];
    }
}

Install with Tessl CLI:
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-modelimport