or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

index.mdneural-networks.md

index.mddocs/

0

# DeepLearning4j

1

2

DeepLearning4j is a comprehensive, Apache 2.0-licensed deep learning library for the JVM that provides neural network implementations, data processing capabilities, and distributed computing integrations. It supports both CPU and GPU execution, integrates with Hadoop and Spark, and offers a complete ecosystem for enterprise-grade deep learning applications.

3

4

## Package Information

5

6

- **Package Name**: deeplearning4j-parent

7

- **Package Type**: maven

8

- **Language**: Java

9

- **Installation**: Add dependency to your `pom.xml`:

10

11

```xml

12

<dependency>

13

<groupId>org.deeplearning4j</groupId>

14

<artifactId>deeplearning4j-core</artifactId>

15

<version>0.9.1</version>

16

</dependency>

17

18

<!-- Backend - choose one based on your hardware -->

19

<!-- CPU Backend -->

20

<dependency>

21

<groupId>org.nd4j</groupId>

22

<artifactId>nd4j-native</artifactId>

23

<version>0.9.1</version>

24

</dependency>

25

26

<!-- GPU Backend (requires CUDA 8.0) -->

27

<dependency>

28

<groupId>org.nd4j</groupId>

29

<artifactId>nd4j-cuda-8.0</artifactId>

30

<version>0.9.1</version>

31

</dependency>

32

33

<!-- For Spark integration -->

34

<dependency>

35

<groupId>org.deeplearning4j</groupId>

36

<artifactId>dl4j-spark_2.11</artifactId>

37

<version>0.9.1_spark_2</version> <!-- Spark artifacts use a _spark_N suffix; use 0.9.1_spark_1 for Spark 1.x -->

38

</dependency>

39

```

40

41

**System Requirements:**

42

- Java 7+ (Java 8+ recommended)

43

- Maven 3.0+ for building

44

- For GPU: CUDA-compatible GPU with compute capability 3.0+ (Kepler or newer)

45

- For GPU: CUDA 8.0 toolkit installed

46

47

## Core Imports

48

49

```java

50

// Core network classes

51

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

52

import org.deeplearning4j.nn.graph.ComputationGraph;

53

54

// Configuration

55

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

56

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

57

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

58

59

// Common layers

60

import org.deeplearning4j.nn.conf.layers.DenseLayer;

61

import org.deeplearning4j.nn.conf.layers.OutputLayer;

62

import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

63

import org.deeplearning4j.nn.conf.layers.LSTM;

64

65

// Utilities and evaluation

66

import org.deeplearning4j.util.ModelSerializer;

67

import org.deeplearning4j.eval.Evaluation;

68

69

// Data handling (ND4J)

70

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

71

import org.nd4j.linalg.dataset.DataSet;

72

import org.nd4j.linalg.api.ndarray.INDArray;

73

74

// Loss functions (ND4J) and optimization settings (DL4J)

75

import org.nd4j.linalg.lossfunctions.LossFunctions;

76

import org.deeplearning4j.optimize.api.OptimizationAlgorithm;

77

import org.deeplearning4j.nn.conf.Updater;

78

```

79

80

## Basic Usage

81

82

```java

83

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

84

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

85

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

86

import org.deeplearning4j.nn.conf.layers.DenseLayer;

87

import org.deeplearning4j.nn.conf.layers.OutputLayer;

88

import org.deeplearning4j.optimize.api.OptimizationAlgorithm;

89

import org.deeplearning4j.nn.conf.Updater;

90

import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.eval.Evaluation;

91

92

// Create a simple feedforward neural network

93

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

94

.seed(123)

95

.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

96

.updater(Updater.NESTEROVS)

97

.list()

98

.layer(0, new DenseLayer.Builder()

99

.nIn(784)

100

.nOut(100)

101

.activation("relu")

102

.build())

103

.layer(1, new OutputLayer.Builder()

104

.nIn(100)

105

.nOut(10)

106

.activation("softmax")

107

.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)

108

.build())

109

.pretrain(false)

110

.backprop(true)

111

.build();

112

113

MultiLayerNetwork model = new MultiLayerNetwork(conf);

114

model.init();

115

116

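// Train the model. trainingData and testData are DataSetIterator instances,
// e.g. new MnistDataSetIterator(64, true, 123) and new MnistDataSetIterator(64, false, 123)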

117

model.fit(trainingData);

118

119

// Evaluate the model

120

Evaluation eval = model.evaluate(testData);

121

System.out.println(eval.stats());

122

```

123

124

## Architecture

125

126

DeepLearning4j is built around several key components:

127

128

- **Core Network Types**: `MultiLayerNetwork` for sequential architectures and `ComputationGraph` for complex graph-based networks

129

- **Configuration System**: Builder pattern for network configuration with `NeuralNetConfiguration` and layer-specific builders

130

- **Layer Types**: Comprehensive layer implementations including dense, convolutional, recurrent, and normalization layers

131

- **Data Handling**: Integration with ND4J for n-dimensional arrays and DataVec for data preprocessing

132

- **Backend Abstraction**: Support for both CPU (nd4j-native) and GPU (nd4j-cuda) execution via ND4J

133

- **Distribution Support**: Native integration with Apache Spark and Hadoop for distributed training

134

- **Model Management**: Serialization, import/export, and model zoo for pre-trained networks

135

136

## Capabilities

137

138

### Neural Networks

139

140

Core neural network construction with support for both sequential (MultiLayerNetwork) and graph-based (ComputationGraph) architectures.

141

142

```java { .api }

143

// Sequential networks

144

public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

145

public MultiLayerNetwork(MultiLayerConfiguration conf);

146

public void fit(DataSetIterator iterator);

147

public INDArray output(INDArray input);

148

public Evaluation evaluate(DataSetIterator iterator);

149

}

150

151

// Graph networks

152

public class ComputationGraph implements Serializable, Model, NeuralNetwork {

153

public ComputationGraph(ComputationGraphConfiguration configuration);

154

public void fit(MultiDataSetIterator iterator);

155

public INDArray[] output(INDArray... input);

156

public Evaluation evaluate(DataSetIterator iterator);

157

}

158

```

159

160

[Neural Networks](./neural-networks.md)

161

162

### Network Configuration

163

164

Comprehensive configuration system using builder patterns for network architecture definition.

165

166

```java { .api }

167

public class NeuralNetConfiguration {

168

public static class Builder {

169

public Builder seed(long seed);

170

public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo);

171

public Builder updater(Updater updater);

172

public Builder learningRate(double learningRate);

173

public MultiLayerConfiguration.Builder list();

174

}

175

}

176

177

public class MultiLayerConfiguration {

178

public static class Builder {

179

public Builder layer(int layerIndex, Layer layer);

180

public Builder pretrain(boolean pretrain);

181

public Builder backprop(boolean backprop);

182

public MultiLayerConfiguration build();

183

}

184

}

185

```

186

187

[Configuration](./configuration.md)

188

189

### Layer Types

190

191

Comprehensive collection of layer implementations for various neural network architectures.

192

193

```java { .api }

194

// Dense layers

195

public class DenseLayer extends FeedForwardLayer {

196

public static class Builder extends FeedForwardLayer.Builder<Builder> {

197

public Builder nIn(int nIn);

198

public Builder nOut(int nOut);

199

public Builder activation(String activation);

200

}

201

}

202

203

// Convolutional layers

204

public class ConvolutionLayer extends FeedForwardLayer {

205

public static class Builder {

206

public Builder kernelSize(int... kernelSize);

207

public Builder stride(int... stride);

208

public Builder padding(int... padding);

209

public Builder nIn(int nIn);

210

public Builder nOut(int nOut);

211

}

212

}

213

214

// Recurrent layers

215

public class LSTM extends BaseRecurrentLayer {

216

public static class Builder extends BaseRecurrentLayer.Builder<Builder> {

217

public Builder forgetGateBiasInit(double forgetGateBiasInit);

218

public Builder gateActivationFunction(String gateActivationFunction);

219

}

220

}

221

```

222

223

[Layers](./layers.md)

224

225

### Model Management

226

227

Model serialization, loading, and persistence utilities for production deployment.

228

229

```java { .api }

230

public class ModelSerializer {

231

public static void writeModel(Model model, String path, boolean saveUpdater) throws IOException;

232

public static void writeModel(Model model, File file, boolean saveUpdater) throws IOException;

233

public static MultiLayerNetwork restoreMultiLayerNetwork(String path) throws IOException;

234

public static MultiLayerNetwork restoreMultiLayerNetwork(File file) throws IOException;

235

public static ComputationGraph restoreComputationGraph(String path) throws IOException;

236

public static ComputationGraph restoreComputationGraph(File file) throws IOException;

237

}

238

```

239

240

[Model Management](./model-management.md)

241

242

### Evaluation and Metrics

243

244

Comprehensive evaluation metrics and performance measurement tools.

245

246

```java { .api }

247

public class Evaluation implements IEvaluation<Evaluation> {

248

public Evaluation(int numClasses);

249

public void eval(INDArray labels, INDArray predictions);

250

public double accuracy();

251

public double precision();

252

public double recall();

253

public double f1();

254

public String stats();

255

public String confusionToString();

256

}

257

258

public class RegressionEvaluation implements IEvaluation<RegressionEvaluation> {

259

public RegressionEvaluation(int numColumns);

260

public void eval(INDArray labels, INDArray predictions);

261

public double meanSquaredError(int column);

262

public double meanAbsoluteError(int column);

263

public double correlationR2(int column);

264

}

265

```

266

267

[Evaluation](./evaluation.md)

268

269

### Data Handling

270

271

Data loading, preprocessing, and batch management for training and inference.

272

273

```java { .api }

274

// Core data structures

275

public interface DataSetIterator extends Iterator<DataSet> {

276

boolean hasNext();

277

DataSet next();

278

DataSet next(int num);

279

int totalExamples();

280

int inputColumns();

281

int totalOutcomes();

282

void reset();

283

int batch();

284

}

285

286

// Built-in dataset loaders

287

public class MnistDataSetIterator implements DataSetIterator {

288

public MnistDataSetIterator(int batchSize, boolean train, int seed) throws IOException;

289

}

290

291

public class CifarDataSetIterator implements DataSetIterator {

292

public CifarDataSetIterator(int batchSize, boolean train) throws IOException;

293

}

294

```

295

296

[Data Handling](./data-handling.md)

297

298

### Distributed Computing

299

300

Native integration with Apache Spark and Hadoop for distributed training and inference.

301

302

```java { .api }

303

public class SparkDl4jMultiLayer {

304

public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf, TrainingMaster tm);

305

public MultiLayerNetwork fit(JavaRDD<DataSet> trainingData);

306

public JavaRDD<INDArray> predict(JavaRDD<INDArray> data);

307

public Evaluation evaluate(JavaRDD<DataSet> data);

308

}

309

310

public class SparkComputationGraph {

311

public SparkComputationGraph(JavaSparkContext sc, ComputationGraphConfiguration conf, TrainingMaster tm);

312

public ComputationGraph fitMultiDataSet(JavaRDD<MultiDataSet> trainingData);

313

public JavaRDD<INDArray[]> predict(JavaRDD<INDArray[]> data);

314

}

315

```

316

317

[Distributed Computing](./distributed-computing.md)

318

319

## Types

320

321

```java { .api }

322

// Core interfaces

323

public interface NeuralNetwork {

324

INDArray output(INDArray input);

325

void fit(DataSetIterator iterator);

326

Evaluation evaluate(DataSetIterator iterator);

327

}

328

329

public interface Model extends Serializable {

330

void fit(INDArray data);

330

INDArray params();

331

double score(); // no save() here: persist models via ModelSerializer.writeModel(...)

333

}

334

335

// Network types

336

public interface Layer extends Serializable, Cloneable {

337

INDArray activate(INDArray input, boolean training);

338

int numParams();

339

void setParams(INDArray params);

340

INDArray params();

341

}

342

343

// Optimization

344

public enum OptimizationAlgorithm {

345

STOCHASTIC_GRADIENT_DESCENT,

346

CONJUGATE_GRADIENT,

347

LBFGS,

348

LINE_GRADIENT_DESCENT

349

}

350

```