DeepLearning4j is a comprehensive deep learning library for the JVM that provides neural network implementations, data processing capabilities, and distributed computing integrations.
—
Core neural network construction with support for both sequential (MultiLayerNetwork) and graph-based (ComputationGraph) architectures.
A sequential network architecture (MultiLayerNetwork) for feedforward, convolutional, or recurrent models, built by stacking layers linearly so each layer feeds the next.
/**
 * Multi-layer neural network implementation for sequential architectures:
 * layers are stacked linearly and applied in order during the forward pass.
 */
public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
    /** Create an uninitialized network from the given layer configuration */
    public MultiLayerNetwork(MultiLayerConfiguration conf);
    /** Initialize network parameters; must be called before fit/output */
    public void init();
    /** Train network on all data produced by the dataset iterator */
    public void fit(DataSetIterator iterator);
    /** Train network on a single dataset */
    public void fit(DataSet dataSet);
    /** Get network output for input */
    public INDArray output(INDArray input);
    /** Get network output with explicit training flag */
    public INDArray output(INDArray input, boolean train);
    /** Evaluate network performance over the iterator */
    public Evaluation evaluate(DataSetIterator iterator);
    /** Get network score (loss) on dataset */
    public double score(DataSet dataSet);
    /** Get current network parameters */
    public INDArray params();
    /** Set network parameters */
    public void setParams(INDArray params);
    /** Get network gradients computed by the last backward pass */
    public Gradient gradient();
    /** Clear stored recurrent state (for RNN layers) */
    public void rnnClearPreviousState();
}

Usage Examples:
// Imports for this example (Updater, Activation, LossFunctions, Evaluation
// and the iterator types were missing from the original listing)
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
// Create configuration: a 784 -> 256 -> 128 -> 10 multilayer perceptron
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(123)                        // fixed RNG seed for reproducible initialization
    .updater(Updater.ADAM)
    .list()
    .layer(0, new DenseLayer.Builder()
        .nIn(784)
        .nOut(256)
        .activation(Activation.RELU)  // enum form; the String overload is the deprecated pre-0.7 API
        .build())
    .layer(1, new DenseLayer.Builder()
        .nIn(256)
        .nOut(128)
        .activation(Activation.RELU)
        .build())
    .layer(2, new OutputLayer.Builder()
        .nIn(128)
        .nOut(10)
        .activation(Activation.SOFTMAX)
        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .build())
    .build();
// Create and initialize network (init() must be called before training)
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
// Train network (trainingDataIterator is a DataSetIterator over the training set)
network.fit(trainingDataIterator);
// Evaluate performance on held-out data
Evaluation eval = network.evaluate(testDataIterator);
System.out.println("Accuracy: " + eval.accuracy());

Complex graph-based neural network architecture for networks with multiple inputs/outputs, skip connections, merging/splitting.
/**
 * Computation graph implementation for complex network architectures:
 * multiple inputs/outputs, skip connections, and merge/split vertices.
 */
public class ComputationGraph implements Serializable, Model, NeuralNetwork {
    /** Create an uninitialized graph from the given configuration */
    public ComputationGraph(ComputationGraphConfiguration configuration);
    /** Initialize graph parameters; must be called before fit/output */
    public void init();
    /** Train graph on all data produced by the multi-dataset iterator */
    public void fit(MultiDataSetIterator iterator);
    /** Train graph on a single multi-dataset */
    public void fit(MultiDataSet multiDataSet);
    /** Get single output for single input (convenience for single-output graphs) */
    public INDArray outputSingle(INDArray input);
    /** Get multiple outputs for multiple inputs.
     *  NOTE(review): was mis-declared as outputSingle, contradicting this
     *  comment and the INDArray[] return; the DL4J API name is output. */
    public INDArray[] output(INDArray... input);
    /** Get outputs with explicit training flag */
    public INDArray[] output(boolean train, INDArray... input);
    /** Evaluate graph performance over the iterator */
    public Evaluation evaluate(DataSetIterator iterator);
    /** Get graph score (loss) on multi-dataset */
    public double score(MultiDataSet multiDataSet);
    /** Get current graph parameters */
    public INDArray params();
    /** Set graph parameters */
    public void setParams(INDArray params);
    /** Clear stored recurrent state (for RNN layers) */
    public void rnnClearPreviousState();
}

Usage Examples:
// Imports for this example (layer, Updater, Activation and LossFunctions
// imports were missing from the original listing)
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
// Create complex graph configuration: named vertices wired explicitly,
// input -> dense1 -> dense2 -> output
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(123)                        // fixed RNG seed for reproducible initialization
    .updater(Updater.ADAM)
    .graphBuilder()
    .addInputs("input")
    .addLayer("dense1", new DenseLayer.Builder()
        .nIn(784)
        .nOut(256)
        .activation(Activation.RELU)  // enum form; the String overload is the deprecated pre-0.7 API
        .build(), "input")
    .addLayer("dense2", new DenseLayer.Builder()
        .nIn(256)
        .nOut(128)
        .activation(Activation.RELU)
        .build(), "dense1")
    .addLayer("output", new OutputLayer.Builder()
        .nIn(128)
        .nOut(10)
        .activation(Activation.SOFTMAX)
        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .build(), "dense2")
    .setOutputs("output")
    .build();
// Create and initialize graph (init() must be called before training)
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
// Train graph (multiDataSetIterator is a MultiDataSetIterator over the training set)
graph.fit(multiDataSetIterator);
// Get predictions (single input, single output)
INDArray predictions = graph.outputSingle(inputData);

Core interfaces that both MultiLayerNetwork and ComputationGraph implement.
/**
 * Base interface for neural networks.
 * Implemented by both MultiLayerNetwork and ComputationGraph.
 */
public interface NeuralNetwork {
/** Get network output for input */
INDArray output(INDArray input);
/** Train network on all data produced by the dataset iterator */
void fit(DataSetIterator iterator);
/** Evaluate network performance over the iterator */
Evaluation evaluate(DataSetIterator iterator);
/** Get current network parameters */
INDArray params();
/** Set network parameters */
void setParams(INDArray params);
}
/**
 * Model interface for serializable models.
 * Extends Serializable, so every Model can be persisted.
 */
public interface Model extends Serializable {
/** Train model on all data produced by the dataset iterator */
void fit(DataSetIterator iterator);
/** Get model output for input */
INDArray output(INDArray input);
/** Save model (configuration and parameters) to file */
void save(File file) throws IOException;
/** Get current model parameters */
INDArray params();
/** Set model parameters */
void setParams(INDArray params);
}
/**
 * Classifier interface for classification models.
 */
public interface Classifier {
/** Get the predicted class index for each example in the input batch */
int[] predict(INDArray input);
/** Get per-class probability distributions for the input */
INDArray output(INDArray input);
}
// Configuration types
/** Configuration for sequential (MultiLayerNetwork) architectures */
public class MultiLayerConfiguration implements Serializable {
// Configuration for sequential networks; built via NeuralNetConfiguration.Builder.list()
}
/** Configuration for graph (ComputationGraph) architectures */
public class ComputationGraphConfiguration implements Serializable {
// Configuration for graph networks; built via NeuralNetConfiguration.Builder.graphBuilder()
}
// Network state and training
/**
 * Gradient information produced by backpropagation, keyed by parameter
 * variable name.
 */
public interface Gradient {
    /** Gradient array for a single named variable */
    INDArray getGradientFor(String variable);
    /** All gradients, keyed by variable name */
    Map<String, INDArray> gradientForVariable();
}

Install with Tessl CLI
npx tessl i tessl/maven-org-deeplearning4j--deeplearning4j-parent@0.9.1