# Neural Networks

Core neural network construction, supporting both sequential (`MultiLayerNetwork`) and graph-based (`ComputationGraph`) architectures.

## Capabilities

### MultiLayerNetwork

Sequential neural network architecture for feed-forward, convolutional, or recurrent networks built from a linear stack of layers, each feeding the next.

```java { .api }
/**
12
* Multi-layer neural network implementation for sequential architectures
13
*/
14
public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
15
/** Create network from configuration */
16
public MultiLayerNetwork(MultiLayerConfiguration conf);
17
18
/** Initialize network parameters */
19
public void init();
20
21
/** Train network on dataset iterator */
22
public void fit(DataSetIterator iterator);
23
24
/** Train network on single dataset */
25
public void fit(DataSet dataSet);
26
27
/** Get network output for input */
28
public INDArray output(INDArray input);
29
30
/** Get network output with training flag */
31
public INDArray output(INDArray input, boolean train);
32
33
/** Evaluate network performance */
34
public Evaluation evaluate(DataSetIterator iterator);
35
36
/** Get network score (loss) on dataset */
37
public double score(DataSet dataSet);
38
39
/** Get current network parameters */
40
public INDArray params();
41
42
/** Set network parameters */
43
public void setParams(INDArray params);
44
45
/** Get network gradients */
46
public Gradient gradient();
47
48
/** Clear network state (for RNNs) */
49
public void rnnClearPreviousState();
50
}
```

**Usage Examples:**

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
57
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
58
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
59
import org.deeplearning4j.nn.conf.layers.DenseLayer;
60
import org.deeplearning4j.nn.conf.layers.OutputLayer;
61
62
// Create configuration
63
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
64
.seed(123)
65
.updater(Updater.ADAM)
66
.list()
67
.layer(0, new DenseLayer.Builder()
68
.nIn(784)
69
.nOut(256)
70
.activation("relu")
71
.build())
72
.layer(1, new DenseLayer.Builder()
73
.nIn(256)
74
.nOut(128)
75
.activation("relu")
76
.build())
77
.layer(2, new OutputLayer.Builder()
78
.nIn(128)
79
.nOut(10)
80
.activation("softmax")
81
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
82
.build())
83
.build();
84
85
// Create and initialize network
86
MultiLayerNetwork network = new MultiLayerNetwork(conf);
87
network.init();
88
89
// Train network
90
network.fit(trainingDataIterator);
91
92
// Evaluate performance
93
Evaluation eval = network.evaluate(testDataIterator);
94
System.out.println("Accuracy: " + eval.accuracy());
```

### ComputationGraph

Graph-based neural network architecture for complex topologies: multiple inputs and/or outputs, skip connections, and branches that merge or split.

```java { .api }
/**
103
* Computation graph implementation for complex network architectures
104
*/
105
public class ComputationGraph implements Serializable, Model, NeuralNetwork {
106
/** Create graph from configuration */
107
public ComputationGraph(ComputationGraphConfiguration configuration);
108
109
/** Initialize graph parameters */
110
public void init();
111
112
/** Train graph on multi-dataset iterator */
113
public void fit(MultiDataSetIterator iterator);
114
115
/** Train graph on single multi-dataset */
116
public void fit(MultiDataSet multiDataSet);
117
118
/** Get single output for single input */
119
public INDArray outputSingle(INDArray input);
120
121
/** Get multiple outputs for multiple inputs */
122
public INDArray[] outputSingle(INDArray... input);
123
124
/** Get outputs with training flag */
125
public INDArray[] output(boolean train, INDArray... input);
126
127
/** Evaluate graph performance */
128
public Evaluation evaluate(DataSetIterator iterator);
129
130
/** Get graph score (loss) on multi-dataset */
131
public double score(MultiDataSet multiDataSet);
132
133
/** Get current graph parameters */
134
public INDArray params();
135
136
/** Set graph parameters */
137
public void setParams(INDArray params);
138
139
/** Clear graph state (for RNNs) */
140
public void rnnClearPreviousState();
141
}
```

**Usage Examples:**

```java
import org.deeplearning4j.nn.graph.ComputationGraph;
148
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
149
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
150
151
// Create complex graph configuration
152
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
153
.seed(123)
154
.updater(Updater.ADAM)
155
.graphBuilder()
156
.addInputs("input")
157
.addLayer("dense1", new DenseLayer.Builder()
158
.nIn(784)
159
.nOut(256)
160
.activation("relu")
161
.build(), "input")
162
.addLayer("dense2", new DenseLayer.Builder()
163
.nIn(256)
164
.nOut(128)
165
.activation("relu")
166
.build(), "dense1")
167
.addLayer("output", new OutputLayer.Builder()
168
.nIn(128)
169
.nOut(10)
170
.activation("softmax")
171
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
172
.build(), "dense2")
173
.setOutputs("output")
174
.build();
175
176
// Create and initialize graph
177
ComputationGraph graph = new ComputationGraph(conf);
178
graph.init();
179
180
// Train graph
181
graph.fit(multiDataSetIterator);
182
183
// Get predictions
184
INDArray predictions = graph.outputSingle(inputData);
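```

### Network Interfaces

Core interfaces implemented by the network classes. Both `MultiLayerNetwork` and `ComputationGraph` implement `NeuralNetwork`; `MultiLayerNetwork` also implements `Classifier`, and `ComputationGraph` also implements `Model`.

```java { .api }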
/**
193
* Base interface for neural networks
194
*/
195
public interface NeuralNetwork {
196
/** Get network output for input */
197
INDArray output(INDArray input);
198
199
/** Train network on dataset iterator */
200
void fit(DataSetIterator iterator);
201
202
/** Evaluate network performance */
203
Evaluation evaluate(DataSetIterator iterator);
204
205
/** Get network parameters */
206
INDArray params();
207
208
/** Set network parameters */
209
void setParams(INDArray params);
210
}
211
212
/**
213
* Model interface for serializable models
214
*/
215
public interface Model extends Serializable {
216
/** Train model on dataset iterator */
217
void fit(DataSetIterator iterator);
218
219
/** Get model output for input */
220
INDArray output(INDArray input);
221
222
/** Save model to file */
223
void save(File file) throws IOException;
224
225
/** Get model parameters */
226
INDArray params();
227
228
/** Set model parameters */
229
void setParams(INDArray params);
230
}
231
232
/**
233
* Classifier interface for classification models
234
*/
235
public interface Classifier {
236
/** Get class predictions for input */
237
int[] predict(INDArray input);
238
239
/** Get class probability distributions */
240
INDArray output(INDArray input);
241
}
```

## Types

```java { .api }
// Configuration types

/**
 * Configuration for sequential networks, passed to the {@code MultiLayerNetwork} constructor.
 * Built with {@code new NeuralNetConfiguration.Builder()...list()...build()}.
 */
public class MultiLayerConfiguration implements Serializable {
    // Configuration for sequential networks
}

/**
 * Configuration for graph networks, passed to the {@code ComputationGraph} constructor.
 * Built with {@code new NeuralNetConfiguration.Builder()...graphBuilder()...build()}.
 */
public class ComputationGraphConfiguration implements Serializable {
    // Configuration for graph networks
}

// Network state and training
257
public interface Gradient {
258
// Gradient information for backpropagation
259
INDArray getGradientFor(String variable);
260
Map<String, INDArray> gradientForVariable();
261
}
```