DeepLearning4j is a comprehensive deep learning library for the JVM that provides neural network implementations, data processing capabilities, and distributed computing integrations.
npx @tessl/cli install tessl/maven-org-deeplearning4j--deeplearning4j-parent@0.9.00
# DeepLearning4j

DeepLearning4j is a comprehensive, Apache 2.0-licensed deep learning library for the JVM that provides neural network implementations, data processing capabilities, and distributed computing integrations. It supports both CPU and GPU execution, integrates with Hadoop and Spark, and offers a complete ecosystem for enterprise-grade deep learning applications.

## Package Information

- **Package Name**: deeplearning4j-parent
- **Package Type**: maven
- **Language**: Java
- **Installation**: Add dependency to your `pom.xml`:

```xml
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>0.9.1</version>
</dependency>

<!-- Backend - choose one based on your hardware -->
<!-- CPU Backend -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>0.9.1</version>
</dependency>

<!-- GPU Backend (requires CUDA 8.0) -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-8.0</artifactId>
    <version>0.9.1</version>
</dependency>

<!-- For Spark integration -->
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>dl4j-spark_2.11</artifactId>
    <version>0.9.1</version>
</dependency>
```

**System Requirements:**
- Java 7+ (Java 8+ recommended)
- Maven 3.0+ for building
- For GPU: CUDA-compatible GPU with compute capability 3.0+ (Kepler or newer)
- For GPU: CUDA 8.0 toolkit installed

## Core Imports

```java
// Core network classes
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;

// Configuration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

// Common layers
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;

// Utilities and evaluation
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.eval.Evaluation;

// Data handling (ND4J)
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.api.ndarray.INDArray;

// Loss functions (ND4J)
import org.nd4j.linalg.lossfunctions.LossFunctions;

// Optimization settings (DL4J)
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.Updater;
```

## Basic Usage

```java
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

// Assumption: MNIST (784 inputs, 10 classes) matches the layer sizes below.
// Batch size 64 and seed 123 are arbitrary illustrative values.
DataSetIterator trainingData = new MnistDataSetIterator(64, true, 123);
DataSetIterator testData = new MnistDataSetIterator(64, false, 123);

// Create a simple feedforward neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(123)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .updater(Updater.NESTEROVS)
    .list()
    .layer(0, new DenseLayer.Builder()
        .nIn(784)
        .nOut(100)
        .activation("relu")
        .build())
    .layer(1, new OutputLayer.Builder()
        .nIn(100)
        .nOut(10)
        .activation("softmax")
        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .build())
    .pretrain(false)
    .backprop(true)
    .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

// Train the model (one pass over the training iterator)
model.fit(trainingData);

// Evaluate the model on held-out data
Evaluation eval = model.evaluate(testData);
System.out.println(eval.stats());
```
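
As a sanity check on the configuration above, the number of trainable parameters in a stack of dense layers can be worked out by hand: each layer contributes `nIn * nOut` weights plus `nOut` biases. The sketch below is plain Java and does not use DL4J; the class name `DenseParamCount` is hypothetical and exists only for illustration.

```java
// Illustrative helper, not part of DL4J: counts trainable parameters
// (weights plus biases) for a stack of fully connected layers.
final class DenseParamCount {

    /** layerSizes = {nIn, hidden1, ..., nOut}. */
    static long count(int... layerSizes) {
        long total = 0;
        for (int i = 0; i + 1 < layerSizes.length; i++) {
            long nIn = layerSizes[i];
            long nOut = layerSizes[i + 1];
            total += nIn * nOut + nOut; // weight matrix + bias vector
        }
        return total;
    }

    public static void main(String[] args) {
        // 784 -> 100 -> 10, the sizes used in the Basic Usage network:
        // (784 * 100 + 100) + (100 * 10 + 10) = 78500 + 1010
        System.out.println(count(784, 100, 10)); // prints 79510
    }
}
```

If the network is built as shown, comparing this figure with `model.numParams()` should be a quick way to confirm the layer sizes were wired as intended.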

## Architecture

DeepLearning4j is built around several key components:

- **Core Network Types**: `MultiLayerNetwork` for sequential architectures and `ComputationGraph` for complex graph-based networks
- **Configuration System**: Builder pattern for network configuration with `NeuralNetConfiguration` and layer-specific builders
- **Layer Types**: Comprehensive layer implementations including dense, convolutional, recurrent, and normalization layers
- **Data Handling**: Integration with ND4J for n-dimensional arrays and DataVec for data preprocessing
- **Backend Abstraction**: Support for both CPU (nd4j-native) and GPU (nd4j-cuda) execution via ND4J
- **Distribution Support**: Native integration with Apache Spark and Hadoop for distributed training
- **Model Management**: Serialization, import/export, and model zoo for pre-trained networks

## Capabilities

### Neural Networks

Core neural network construction with support for both sequential (`MultiLayerNetwork`) and graph-based (`ComputationGraph`) architectures.

```java { .api }
// Sequential networks
public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
    public MultiLayerNetwork(MultiLayerConfiguration conf);
    public void fit(DataSetIterator iterator);
    public INDArray output(INDArray input);
    public Evaluation evaluate(DataSetIterator iterator);
}

// Graph networks
public class ComputationGraph implements Serializable, Model, NeuralNetwork {
    public ComputationGraph(ComputationGraphConfiguration configuration);
    public void fit(MultiDataSetIterator iterator);
    public INDArray[] output(INDArray... input);       // one array per network output
    public INDArray outputSingle(INDArray... input);   // for graphs with exactly one output
    public Evaluation evaluate(DataSetIterator iterator);
}
```

[Neural Networks](./neural-networks.md)

### Network Configuration

Comprehensive configuration system using builder patterns for network architecture definition.

```java { .api }
public class NeuralNetConfiguration {
    public static class Builder {
        public Builder seed(long seed);
        public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo);
        public Builder updater(Updater updater);
        public Builder learningRate(double learningRate);
        public MultiLayerConfiguration.Builder list();
    }
}

public class MultiLayerConfiguration {
    public static class Builder {
        public Builder layer(int layerIndex, Layer layer);
        public Builder pretrain(boolean pretrain);
        public Builder backprop(boolean backprop);
        public MultiLayerConfiguration build();
    }
}
```

[Configuration](./configuration.md)

### Layer Types

Comprehensive collection of layer implementations for various neural network architectures.

```java { .api }
// Dense layers
public class DenseLayer extends FeedForwardLayer {
    public static class Builder extends FeedForwardLayer.Builder<Builder> {
        public Builder nIn(int nIn);
        public Builder nOut(int nOut);
        public Builder activation(String activation);
    }
}

// Convolutional layers
public class ConvolutionLayer extends FeedForwardLayer {
    public static class Builder {
        public Builder kernelSize(int... kernelSize);
        public Builder stride(int... stride);
        public Builder padding(int... padding);
        public Builder nIn(int nIn);
        public Builder nOut(int nOut);
    }
}

// Recurrent layers
public class LSTM extends BaseRecurrentLayer {
    public static class Builder extends BaseRecurrentLayer.Builder<Builder> {
        public Builder forgetGateBiasInit(double forgetGateBiasInit);
        public Builder gateActivationFunction(String gateActivationFunction);
    }
}
```
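
The `kernelSize`, `stride`, and `padding` settings together determine the spatial size of a convolution's output, which in turn fixes the `nIn` of whatever layer follows. A minimal sketch of the standard formula `(in + 2 * padding - kernel) / stride + 1` is shown below in plain Java. The class `ConvOutputSize` is hypothetical, and the sketch assumes floor (truncating) division with no dilation; DL4J's convolution modes may treat sizes that do not divide evenly differently.

```java
// Illustrative helper, not part of DL4J: output length along one spatial
// dimension of a convolution, assuming truncating division and no dilation.
final class ConvOutputSize {

    static int outputSize(int inputSize, int kernelSize, int stride, int padding) {
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive");
        }
        int padded = inputSize + 2 * padding;
        if (kernelSize > padded) {
            throw new IllegalArgumentException("kernel is larger than the padded input");
        }
        return (padded - kernelSize) / stride + 1;
    }

    public static void main(String[] args) {
        // 28x28 input, 5x5 kernel, stride 1, no padding: (28 - 5) / 1 + 1
        System.out.println(outputSize(28, 5, 1, 0)); // prints 24
        // Same input with stride 2 and padding 2: (32 - 5) / 2 + 1
        System.out.println(outputSize(28, 5, 2, 2)); // prints 14
    }
}
```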

[Layers](./layers.md)

### Model Management

Model serialization, loading, and persistence utilities for production deployment.

```java { .api }
public class ModelSerializer {
    public static void writeModel(Model model, String path, boolean saveUpdater) throws IOException;
    public static void writeModel(Model model, File file, boolean saveUpdater) throws IOException;
    public static MultiLayerNetwork restoreMultiLayerNetwork(String path) throws IOException;
    public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException;
    public static ComputationGraph restoreComputationGraph(String path) throws IOException;
    public static ComputationGraph restoreComputationGraph(File file) throws IOException;
}
```

[Model Management](./model-management.md)

### Evaluation and Metrics

Comprehensive evaluation metrics and performance measurement tools.

```java { .api }
public class Evaluation implements IEvaluation<Evaluation> {
    public Evaluation(int numClasses);
    public void eval(INDArray labels, INDArray predictions);
    public double accuracy();
    public double precision();
    public double recall();
    public double f1();
    public String stats();
    public String confusionToString();
}

public class RegressionEvaluation implements IEvaluation<RegressionEvaluation> {
    public RegressionEvaluation(int numColumns);
    public void eval(INDArray labels, INDArray predictions);
    public double meanSquaredError(int column);
    public double meanAbsoluteError(int column);
    public double correlationR2(int column);
}
```
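
To make the classification metrics concrete, the sketch below derives accuracy, precision, recall, and F1 from a confusion matrix in plain Java. It is not DL4J's implementation: the class `ConfusionMetrics` and its `counts[actual][predicted]` layout are assumptions made for illustration, and the way `Evaluation` averages per-class values across classes may differ from the single-class figures computed here.

```java
// Illustrative helper, not part of DL4J: classification metrics computed
// from a confusion matrix laid out as counts[actualClass][predictedClass].
final class ConfusionMetrics {

    static double accuracy(int[][] counts) {
        long correct = 0;
        long total = 0;
        for (int a = 0; a < counts.length; a++) {
            for (int p = 0; p < counts[a].length; p++) {
                total += counts[a][p];
                if (a == p) {
                    correct += counts[a][p];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /** Of everything predicted as cls, the fraction that really was cls. */
    static double precision(int[][] counts, int cls) {
        long predicted = 0;
        for (int a = 0; a < counts.length; a++) {
            predicted += counts[a][cls];
        }
        return predicted == 0 ? 0.0 : (double) counts[cls][cls] / predicted;
    }

    /** Of everything that really was cls, the fraction predicted as cls. */
    static double recall(int[][] counts, int cls) {
        long actual = 0;
        for (int p = 0; p < counts[cls].length; p++) {
            actual += counts[cls][p];
        }
        return actual == 0 ? 0.0 : (double) counts[cls][cls] / actual;
    }

    /** Harmonic mean of precision and recall for one class. */
    static double f1(int[][] counts, int cls) {
        double p = precision(counts, cls);
        double r = recall(counts, cls);
        return (p + r) == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        int[][] counts = { {8, 2}, {1, 9} }; // made-up two-class example
        System.out.println(accuracy(counts));     // 17 correct out of 20
        System.out.println(precision(counts, 0)); // 8 / (8 + 1)
        System.out.println(recall(counts, 0));    // 8 / (8 + 2)
        System.out.println(f1(counts, 0));        // 16 / 19
    }
}
```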

[Evaluation](./evaluation.md)

### Data Handling

Data loading, preprocessing, and batch management for training and inference.

```java { .api }
// Core data structures
public interface DataSetIterator extends Iterator<DataSet> {
    boolean hasNext();
    DataSet next();
    DataSet next(int num);
    int totalExamples();
    int inputColumns();
    int totalOutcomes();
    void reset();
    int batch();
}

// Built-in dataset loaders
public class MnistDataSetIterator implements DataSetIterator {
    public MnistDataSetIterator(int batchSize, boolean train, int seed) throws IOException;
}

public class CifarDataSetIterator implements DataSetIterator {
    public CifarDataSetIterator(int batchSize, boolean train) throws IOException;
}
```
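
The iterator contract above (`hasNext`, `next`, `reset`, `batch`) is essentially a cursor over mini-batches. The following plain-Java sketch mimics that contract over an in-memory array to show the control flow. `ArrayBatchIterator` is a hypothetical class, it yields `double[][]` batches rather than `DataSet` objects, and it does not implement DL4J's `DataSetIterator` interface.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

// Illustrative only, not part of DL4J: a mini-batch cursor over in-memory rows.
final class ArrayBatchIterator implements Iterator<double[][]> {
    private final double[][] examples;
    private final int batchSize;
    private int cursor = 0;

    ArrayBatchIterator(double[][] examples, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.examples = examples;
        this.batchSize = batchSize;
    }

    @Override
    public boolean hasNext() {
        return cursor < examples.length;
    }

    @Override
    public double[][] next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        // The final batch may be smaller than batchSize.
        int end = Math.min(cursor + batchSize, examples.length);
        double[][] batch = Arrays.copyOfRange(examples, cursor, end);
        cursor = end;
        return batch;
    }

    /** Rewind so the data can be iterated again, e.g. for another epoch. */
    void reset() {
        cursor = 0;
    }

    int batch() {
        return batchSize;
    }

    int totalExamples() {
        return examples.length;
    }

    public static void main(String[] args) {
        double[][] data = { {1}, {2}, {3}, {4}, {5} };
        ArrayBatchIterator it = new ArrayBatchIterator(data, 2);
        while (it.hasNext()) {
            System.out.println(it.next().length); // prints 2, 2, 1
        }
    }
}
```

Training for several epochs with an iterator of this shape typically means calling `reset()` between passes over the data.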

[Data Handling](./data-handling.md)

### Distributed Computing

Native integration with Apache Spark and Hadoop for distributed training and inference.

```java { .api }
public class SparkDl4jMultiLayer {
    public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf, TrainingMaster tm);
    public MultiLayerNetwork fit(JavaRDD<DataSet> trainingData);
    public JavaRDD<INDArray> predict(JavaRDD<INDArray> data);
    public Evaluation evaluate(JavaRDD<DataSet> data);
}

public class SparkComputationGraph {
    public SparkComputationGraph(JavaSparkContext sc, ComputationGraphConfiguration conf, TrainingMaster tm);
    public ComputationGraph fit(JavaRDD<MultiDataSet> trainingData);
    public JavaRDD<INDArray[]> predict(JavaRDD<INDArray[]> data);
}
```
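
The `TrainingMaster` passed to these constructors controls how results from Spark workers are combined. One common strategy is parameter averaging, where each worker trains a copy of the network on its own data partition and the resulting parameter vectors are averaged element-wise. The sketch below shows only that averaging step in plain Java; the class `ParameterAveraging` is hypothetical and is not DL4J's implementation, which also has to handle broadcasting, updater state, and scheduling.

```java
// Illustrative only, not part of DL4J: element-wise mean of the flattened
// parameter vectors produced by several workers.
final class ParameterAveraging {

    static double[] average(double[][] workerParams) {
        if (workerParams.length == 0) {
            throw new IllegalArgumentException("at least one worker is required");
        }
        int n = workerParams[0].length;
        double[] avg = new double[n];
        for (double[] params : workerParams) {
            if (params.length != n) {
                throw new IllegalArgumentException("parameter vectors differ in length");
            }
            for (int i = 0; i < n; i++) {
                avg[i] += params[i];
            }
        }
        for (int i = 0; i < n; i++) {
            avg[i] /= workerParams.length;
        }
        return avg;
    }

    public static void main(String[] args) {
        double[][] fromWorkers = { {1.0, 2.0}, {3.0, 4.0} }; // made-up values
        double[] merged = average(fromWorkers);
        System.out.println(merged[0] + " " + merged[1]); // prints 2.0 3.0
    }
}
```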

[Distributed Computing](./distributed-computing.md)

## Types

```java { .api }
// Core interfaces
public interface NeuralNetwork {
    INDArray output(INDArray input);
    void fit(DataSetIterator iterator);
    Evaluation evaluate(DataSetIterator iterator);
}

public interface Model extends Serializable {
    void fit(DataSetIterator iterator);
    INDArray output(INDArray input);
    void save(File file) throws IOException;
}

// Network types
public interface Layer extends Serializable, Cloneable {
    INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr);
    long numParams();
    void setParams(INDArray params);
    INDArray getParams();
}

// Optimization
public enum OptimizationAlgorithm {
    STOCHASTIC_GRADIENT_DESCENT,
    CONJUGATE_GRADIENT,
    LBFGS,
    LINE_GRADIENT_DESCENT
}
```