# Regression

Supervised learning algorithms for predicting continuous values. Smile Core provides comprehensive regression capabilities, from traditional linear models to advanced ensemble methods, kernel machines, and neural networks.

## Capabilities

### Core Regression Interface

All regression algorithms implement the unified `Regression<T>` interface, providing consistent prediction methods and optional online learning support.

```java { .api }
/**
 * Base regression interface for all supervised learning algorithms.
 * @param <T> the type of input objects
 */
interface Regression<T> extends ToDoubleFunction<T>, Serializable {
    /** Predict the target value for the input */
    double predict(T x);

    /** Online learning update; optional, throws UnsupportedOperationException if unsupported */
    void update(T x, double y);

    /** Create an ensemble of multiple regressors */
    static <T> Regression<T> ensemble(Regression<T>... regressors);
}
```

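To make the contract concrete, here is a minimal, self-contained sketch. The interface is re-declared locally as a stand-in (so the snippet compiles without Smile on the classpath; it is not Smile's exact source), and implemented by a trivial model that always predicts the training-set mean. The `applyAsDouble` delegation is what lets any regression model plug in wherever a `ToDoubleFunction` is expected.

```java
import java.io.Serializable;
import java.util.function.ToDoubleFunction;

// Local stand-in for smile.regression.Regression<T>, so the sketch
// compiles without Smile on the classpath.
interface Regression<T> extends ToDoubleFunction<T>, Serializable {
    double predict(T x);

    // Delegation that lets any regression model be used wherever a
    // ToDoubleFunction is expected.
    default double applyAsDouble(T x) {
        return predict(x);
    }

    // Online learning is optional; models without it throw.
    default void update(T x, double y) {
        throw new UnsupportedOperationException("update not supported");
    }
}

public class RegressionContract {
    /** Trivial model: always predicts the training-set mean. */
    public static Regression<double[]> meanModel(double[] y) {
        double sum = 0.0;
        for (double v : y) sum += v;
        double mean = sum / y.length;
        return x -> mean;
    }

    public static void main(String[] args) {
        Regression<double[]> model = meanModel(new double[] {1.0, 2.0, 3.0});
        System.out.println(model.predict(new double[] {42.0})); // prints 2.0
    }
}
```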
### Linear Regression Models

A family of linear regression algorithms with various regularization techniques.

```java { .api }
/**
 * Ordinary Least Squares regression
 */
class OLS implements Regression<double[]> {
    /** Train OLS regression */
    public static OLS fit(double[][] x, double[] y);

    /** Train with intercept control */
    public static OLS fit(double[][] x, double[] y, boolean intercept);

    /** Predict target value */
    public double predict(double[] x);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get intercept term */
    public double intercept();

    /** Get R-squared value */
    public double RSquared();

    /** Get adjusted R-squared */
    public double adjustedRSquared();

    /** Get residual sum of squares */
    public double RSS();

    /** Get total sum of squares */
    public double TSS();
}

/**
 * Ridge regression with L2 regularization
 */
class RidgeRegression implements Regression<double[]> {
    /** Train ridge regression */
    public static RidgeRegression fit(double[][] x, double[] y, double lambda);

    /** Predict target value */
    public double predict(double[] x);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get intercept term */
    public double intercept();

    /** Get regularization parameter */
    public double lambda();
}

/**
 * LASSO regression with L1 regularization
 */
class LASSO implements Regression<double[]> {
    /** Train LASSO regression */
    public static LASSO fit(double[][] x, double[] y, double lambda);

    /** Train with tolerance and max iterations */
    public static LASSO fit(double[][] x, double[] y, double lambda, double tolerance, int maxIter);

    /** Predict target value */
    public double predict(double[] x);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get intercept term */
    public double intercept();

    /** Get L1 penalty parameter */
    public double lambda();
}

/**
 * Elastic Net regression combining L1 and L2 regularization
 */
class ElasticNet implements Regression<double[]> {
    /** Train Elastic Net regression */
    public static ElasticNet fit(double[][] x, double[] y, double lambda1, double lambda2);

    /** Train with convergence parameters */
    public static ElasticNet fit(double[][] x, double[] y, double lambda1, double lambda2, double tolerance, int maxIter);

    /** Predict target value */
    public double predict(double[] x);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get L1 penalty parameter */
    public double lambda1();

    /** Get L2 penalty parameter */
    public double lambda2();
}
```

**Usage Example:**

```java
import smile.regression.*;

// Train linear models
OLS ols = OLS.fit(x, y);
RidgeRegression ridge = RidgeRegression.fit(x, y, 0.1);
LASSO lasso = LASSO.fit(x, y, 0.01);

// Make predictions
double prediction = ols.predict(newSample);
double ridgePred = ridge.predict(newSample);

// Get model statistics
double r2 = ols.RSquared();
double[] coeffs = ridge.coefficients();
```

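For intuition about what `RidgeRegression.fit` solves, here is a from-scratch sketch of the closed-form estimate w = (XᵀX + λI)⁻¹Xᵀy in plain Java, using Gaussian elimination and no Smile dependency. It is illustrative only; note in particular that libraries typically leave the intercept unpenalized, which this sketch does not.

```java
public class RidgeSketch {
    /** Solve A w = b by Gaussian elimination with partial pivoting. */
    static double[] solve(double[][] A, double[] b) {
        int n = b.length;
        double[][] m = new double[n][n + 1];
        for (int i = 0; i < n; i++) {
            System.arraycopy(A[i], 0, m[i], 0, n);
            m[i][n] = b[i];
        }
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++)
                if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) pivot = r;
            double[] t = m[col]; m[col] = m[pivot]; m[pivot] = t;
            for (int r = col + 1; r < n; r++) {
                double f = m[r][col] / m[col][col];
                for (int c = col; c <= n; c++) m[r][c] -= f * m[col][c];
            }
        }
        double[] w = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double s = m[i][n];
            for (int j = i + 1; j < n; j++) s -= m[i][j] * w[j];
            w[i] = s / m[i][i];
        }
        return w;
    }

    /**
     * Ridge coefficients: (X'X + lambda I)^-1 X'y.
     * Column 0 of x is an intercept column of ones (penalized here for
     * brevity, unlike most library implementations).
     */
    public static double[] ridge(double[][] x, double[] y, double lambda) {
        int n = x.length, p = x[0].length;
        double[][] xtx = new double[p][p];
        double[] xty = new double[p];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                xty[j] += x[i][j] * y[i];
                for (int k = 0; k < p; k++) xtx[j][k] += x[i][j] * x[i][k];
            }
        }
        for (int j = 0; j < p; j++) xtx[j][j] += lambda;
        return solve(xtx, xty);
    }

    public static void main(String[] args) {
        // y = 1 + 2x exactly; with lambda = 0 ridge reduces to OLS.
        double[][] x = {{1, 0}, {1, 1}, {1, 2}, {1, 3}};
        double[] y = {1, 3, 5, 7};
        double[] w = ridge(x, y, 0.0);
        System.out.printf("intercept=%.3f slope=%.3f%n", w[0], w[1]); // intercept=1.000 slope=2.000
    }
}
```

Increasing `lambda` shrinks the coefficients toward zero, which is the bias/variance trade-off the `lambda()` accessor exposes.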
### Tree-Based Regression

Regression algorithms based on decision trees and ensemble methods.

```java { .api }
/**
 * Regression tree using the CART algorithm
 */
class RegressionTree implements Regression<double[]>, DataFrameRegression {
    /** Train regression tree */
    public static RegressionTree fit(double[][] x, double[] y);

    /** Train with formula on DataFrame */
    public static RegressionTree fit(Formula formula, DataFrame data);

    /** Train with custom parameters */
    public static RegressionTree fit(double[][] x, double[] y, int maxDepth, int maxNodes, int nodeSize);

    /** Predict target value */
    public double predict(double[] x);

    /** Get feature importance */
    public double[] importance();

    /** Get tree structure */
    public String toString();
}

/**
 * Random Forest regression
 */
class RandomForest implements Regression<double[]>, DataFrameRegression {
    /** Train random forest regression */
    public static RandomForest fit(double[][] x, double[] y);

    /** Train with formula on DataFrame */
    public static RandomForest fit(Formula formula, DataFrame data);

    /** Train with custom parameters */
    public static RandomForest fit(double[][] x, double[] y, int numTrees, int mtry, int maxDepth, int nodeSize);

    /** Predict target value */
    public double predict(double[] x);

    /** Get out-of-bag RMSE */
    public double error();

    /** Get feature importance */
    public double[] importance();
}

/**
 * Gradient Tree Boosting regression
 */
class GradientTreeBoost implements Regression<double[]>, DataFrameRegression {
    /** Train gradient boosting regression */
    public static GradientTreeBoost fit(double[][] x, double[] y);

    /** Train with formula on DataFrame */
    public static GradientTreeBoost fit(Formula formula, DataFrame data);

    /** Train with custom parameters */
    public static GradientTreeBoost fit(double[][] x, double[] y, int numTrees, int maxDepth, double shrinkage, double subsample);

    /** Predict target value */
    public double predict(double[] x);

    /** Get feature importance */
    public double[] importance();
}
```

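To see how a CART split is chosen, the depth-1 sketch below scans one feature for the threshold that minimizes the total squared error and predicts the mean of each side. This is the core step a regression tree repeats recursively on each resulting partition; the code is a plain-Java illustration, not Smile's implementation.

```java
import java.util.Arrays;

public class StumpSketch {
    /**
     * Depth-1 regression tree on a single feature: find the threshold
     * minimizing total squared error; predict the leaf means.
     * Returns {threshold, leftMean, rightMean}.
     */
    public static double[] fitStump(double[] x, double[] y) {
        int n = x.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Double.compare(x[a], x[b]));

        double total = 0;
        for (double v : y) total += v;

        double bestScore = Double.POSITIVE_INFINITY;
        double bestThreshold = 0, bestLeft = 0, bestRight = 0;
        double leftSum = 0;
        for (int i = 0; i < n - 1; i++) {
            leftSum += y[idx[i]];
            double rightSum = total - leftSum;
            int nl = i + 1, nr = n - nl;
            // Minimizing SSE is equivalent to maximizing
            // nl*meanL^2 + nr*meanR^2; negate so smaller is better.
            double score = -(leftSum * leftSum / nl + rightSum * rightSum / nr);
            if (score < bestScore && x[idx[i]] != x[idx[i + 1]]) {
                bestScore = score;
                bestThreshold = (x[idx[i]] + x[idx[i + 1]]) / 2;
                bestLeft = leftSum / nl;
                bestRight = rightSum / nr;
            }
        }
        return new double[] {bestThreshold, bestLeft, bestRight};
    }

    public static double predict(double[] stump, double x) {
        return x <= stump[0] ? stump[1] : stump[2];
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 10, 11, 12};
        double[] y = {5, 5, 5, 20, 20, 20};
        double[] stump = fitStump(x, y);
        System.out.println(stump[0]);           // 6.5
        System.out.println(predict(stump, 2));  // 5.0
        System.out.println(predict(stump, 11)); // 20.0
    }
}
```

`maxDepth`, `maxNodes`, and `nodeSize` in the API above simply bound how far this recursive splitting is allowed to go.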
### Kernel Methods

Kernel-based regression algorithms including support vector regression and Gaussian processes.

```java { .api }
/**
 * Support Vector Regression
 */
class SVM implements Regression<double[]> {
    /** Train SVR with RBF kernel */
    public static SVM fit(double[][] x, double[] y, double gamma, double C, double epsilon);

    /** Train SVR with custom kernel */
    public static SVM fit(double[][] x, double[] y, Kernel kernel, double C, double epsilon);

    /** Predict target value */
    public double predict(double[] x);

    /** Get support vectors */
    public SupportVector[] supportVectors();

    /** Get number of support vectors */
    public int numSupportVectors();
}

/**
 * Gaussian Process Regression
 */
class GaussianProcessRegression implements Regression<double[]> {
    /** Train Gaussian process with RBF kernel */
    public static GaussianProcessRegression fit(double[][] x, double[] y, double sigma);

    /** Train with custom kernel and noise */
    public static GaussianProcessRegression fit(double[][] x, double[] y, Kernel kernel, double noise);

    /** Predict target value */
    public double predict(double[] x);

    /** Predict with uncertainty estimate (variance written to the output array) */
    public double predict(double[] x, double[] variance);

    /** Get posterior mean function */
    public double[] mean();

    /** Get kernel function */
    public Kernel kernel();
}
```

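The point prediction a Gaussian process returns is the posterior mean k(x*)ᵀ(K + σ²I)⁻¹y. The plain-Java sketch below computes it directly for an RBF kernel; this is a toy illustration of the math, not Smile's solver, which factors the kernel matrix far more carefully.

```java
public class GPSketch {
    /** Gaussian (RBF) kernel. */
    static double rbf(double[] a, double[] b, double sigma) {
        double d2 = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            d2 += d * d;
        }
        return Math.exp(-d2 / (2 * sigma * sigma));
    }

    /** Solve A w = b by Gaussian elimination with partial pivoting. */
    static double[] solve(double[][] A, double[] b) {
        int n = b.length;
        double[][] m = new double[n][n + 1];
        for (int i = 0; i < n; i++) {
            System.arraycopy(A[i], 0, m[i], 0, n);
            m[i][n] = b[i];
        }
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++)
                if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) pivot = r;
            double[] t = m[col]; m[col] = m[pivot]; m[pivot] = t;
            for (int r = col + 1; r < n; r++) {
                double f = m[r][col] / m[col][col];
                for (int c = col; c <= n; c++) m[r][c] -= f * m[col][c];
            }
        }
        double[] w = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double s = m[i][n];
            for (int j = i + 1; j < n; j++) s -= m[i][j] * w[j];
            w[i] = s / m[i][i];
        }
        return w;
    }

    /** Weights alpha = (K + noise*I)^-1 y of the posterior mean. */
    public static double[] fitAlpha(double[][] x, double[] y, double sigma, double noise) {
        int n = x.length;
        double[][] K = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                K[i][j] = rbf(x[i], x[j], sigma) + (i == j ? noise : 0);
        return solve(K, y);
    }

    /** Posterior mean at a query point: sum_i alpha_i * k(x_i, x*). */
    public static double predict(double[][] x, double[] alpha, double sigma, double[] query) {
        double s = 0;
        for (int i = 0; i < x.length; i++) s += alpha[i] * rbf(x[i], query, sigma);
        return s;
    }

    public static void main(String[] args) {
        double[][] x = {{0}, {1}, {2}, {3}};
        double[] y = {0, 1, 4, 9}; // samples of y = x^2
        double[] alpha = fitAlpha(x, y, 1.0, 1e-6);
        // With near-zero noise the posterior mean interpolates the training data.
        System.out.printf("%.3f%n", predict(x, alpha, 1.0, new double[]{2})); // ~4.000
    }
}
```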
### Neural Network Regression

Multi-layer perceptron for regression tasks with configurable architecture.

```java { .api }
/**
 * Multi-Layer Perceptron regression
 */
class MLP implements Regression<double[]> {
    /** Train MLP regression */
    public static MLP fit(double[][] x, double[] y);

    /** Train with custom architecture */
    public static MLP fit(double[][] x, double[] y, int[] hiddenLayers, ActivationFunction activation);

    /** Train with full configuration */
    public static MLP fit(double[][] x, double[] y, Properties params);

    /** Predict target value */
    public double predict(double[] x);

    /** Online learning update */
    public void update(double[] x, double y);

    /** Get network weights for a layer */
    public double[][] getWeights(int layer);
}
```

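`MLP` is one of the models that supports the online `update(x, y)` pattern. The sketch below shows that pattern in its simplest form: a single linear neuron adjusted by one stochastic-gradient step per sample. It illustrates online learning generally, not Smile's backpropagation.

```java
import java.util.Random;

public class OnlineSketch {
    // Weights of a single linear neuron: w[0] = bias, w[1] = slope.
    static final double[] w = {0.0, 0.0};

    public static double predict(double x) {
        return w[0] + w[1] * x;
    }

    /** Online update: one stochastic gradient step on squared error. */
    public static void update(double x, double y, double lr) {
        double err = y - predict(x);
        w[0] += lr * err;
        w[1] += lr * err * x;
    }

    public static void main(String[] args) {
        // Stream samples of y = 3x + 1; the weights converge to (1, 3).
        Random rng = new Random(42);
        for (int step = 0; step < 2000; step++) {
            double x = rng.nextDouble() * 2 - 1;
            update(x, 3 * x + 1, 0.1);
        }
        System.out.printf("bias=%.2f slope=%.2f%n", w[0], w[1]); // bias=1.00 slope=3.00
    }
}
```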
### Radial Basis Function Networks

RBF networks for non-linear regression with localized activation functions.

```java { .api }
/**
 * Radial Basis Function Network regression
 */
class RBFNetwork implements Regression<double[]> {
    /** Train RBF network with Gaussian RBFs */
    public static RBFNetwork fit(double[][] x, double[] y, int numCenters);

    /** Train with custom RBF and centers */
    public static RBFNetwork fit(double[][] x, double[] y, RBF rbf, double[][] centers);

    /** Predict target value */
    public double predict(double[] x);

    /** Get RBF centers */
    public double[][] centers();

    /** Get output weights */
    public double[] weights();
}
```

### Base Classes and Utilities

Abstract base classes and utility interfaces for regression algorithms.

```java { .api }
/**
 * Base class for linear regression models
 */
abstract class LinearModel implements Regression<double[]> {
    /** Model coefficients */
    public abstract double[] coefficients();

    /** Intercept term */
    public abstract double intercept();

    /** Predict using the linear combination of coefficients and inputs */
    public double predict(double[] x);
}

/**
 * Base class for kernel-based regression
 */
abstract class KernelMachine<T> implements Regression<T> {
    /** Kernel function */
    public abstract Kernel<T> kernel();

    /** Support vectors or training instances */
    public abstract T[] instances();

    /** Instance weights */
    public abstract double[] weights();
}

/**
 * Interface for DataFrame-based regression
 */
interface DataFrameRegression {
    /** Train with formula on DataFrame */
    static Regression<Tuple> fit(Formula formula, DataFrame data);
}
```

### Generalized Linear Models

GLM framework for regression with various distribution families.

```java { .api }
/**
 * Generalized Linear Model
 */
class GLM implements Regression<double[]> {
    /** Train GLM with Gaussian family (equivalent to linear regression) */
    public static GLM fit(double[][] x, double[] y);

    /** Train GLM with the specified family and link */
    public static GLM fit(double[][] x, double[] y, GLM.Family family, GLM.Link link);

    /** Train with regularization */
    public static GLM fit(double[][] x, double[] y, GLM.Family family, GLM.Link link, double lambda, double alpha);

    /** Predict target value */
    public double predict(double[] x);

    /** Get model coefficients */
    public double[] coefficients();

    /** Get deviance */
    public double deviance();

    /** GLM distribution families */
    enum Family { GAUSSIAN, BINOMIAL, POISSON, GAMMA }

    /** GLM link functions */
    enum Link { IDENTITY, LOG, LOGIT, INVERSE, SQRT }
}
```

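In a GLM, the family fixes the error distribution and the link maps the linear predictor onto the mean. A small self-contained illustration for the Poisson family with log link: for an intercept-only model the MLE of the linear predictor is log of the sample mean, and the deviance compares the fitted model against the saturated one. This illustrates the definitions only, not Smile's GLM fitter.

```java
public class GlmSketch {
    /**
     * Poisson deviance: 2 * sum[ y*log(y/mu) - (y - mu) ], with the
     * convention that y*log(y/mu) = 0 when y = 0.
     */
    public static double poissonDeviance(double[] y, double[] mu) {
        double d = 0;
        for (int i = 0; i < y.length; i++) {
            if (y[i] > 0) d += y[i] * Math.log(y[i] / mu[i]);
            d -= y[i] - mu[i];
        }
        return 2 * d;
    }

    public static void main(String[] args) {
        double[] y = {2, 3, 6, 1};
        // Intercept-only Poisson GLM with log link: the MLE of the linear
        // predictor is log(mean(y)), so the fitted mean is just mean(y).
        double mean = (2 + 3 + 6 + 1) / 4.0;    // 3.0
        double b0 = Math.log(mean);             // coefficient on the link scale
        double[] mu = {mean, mean, mean, mean}; // inverse link: exp(b0)
        System.out.printf("b0=%.4f deviance=%.4f%n", b0, poissonDeviance(y, mu));
    }
}
```

The same deviance notion specializes to the residual sum of squares for the Gaussian family with identity link, which is why `GLM.fit(x, y)` reduces to ordinary linear regression.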
**Usage Example:**

```java
import smile.regression.GaussianProcessRegression;
import smile.regression.RandomForest;

// Gaussian Process with uncertainty
GaussianProcessRegression gp = GaussianProcessRegression.fit(x, y, 1.0);
double[] variance = new double[1];
double gpPrediction = gp.predict(newSample, variance);
double uncertainty = Math.sqrt(variance[0]);

// Random Forest ensemble
RandomForest rf = RandomForest.fit(x, y, 500, 5, 20, 5);
double rfPrediction = rf.predict(newSample);
double oobError = rf.error();
```

### Training Patterns

All regression algorithms follow consistent training patterns:

**Array-based training:**
```java
Regression model = Algorithm.fit(double[][] x, double[] y);
Regression model = Algorithm.fit(double[][] x, double[] y, parameters...);
```

**DataFrame-based training:**
```java
Regression model = Algorithm.fit(Formula formula, DataFrame data);
```

**Prediction patterns:**
```java
double prediction = model.predict(double[] x);
double prediction = model.predict(double[] x, double[] uncertainty); // For probabilistic models
```

### Common Parameters

Most regression algorithms support these common configuration options:

- **lambda**: Regularization strength (linear models)
- **alpha**: Elastic net mixing parameter between L1 and L2 penalties
- **maxDepth**: Maximum tree depth (tree-based models)
- **numTrees**: Number of trees in the ensemble
- **nodeSize**: Minimum number of samples per leaf
- **shrinkage**: Learning rate (boosting)
- **subsample**: Fraction of samples used to train each tree
- **tolerance**: Convergence tolerance
- **maxIter**: Maximum number of iterations
- **seed**: Random seed for reproducibility
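Since some trainers accept a `java.util.Properties` bundle (e.g. `MLP.fit(x, y, Properties params)`), options like these are often passed as string key/value pairs. The sketch below shows the pattern only; the key names are illustrative placeholders, not Smile's actual property names, which differ per model and are documented in its javadoc.

```java
import java.util.Properties;

public class ParamsSketch {
    public static Properties params() {
        Properties params = new Properties();
        // Illustrative key names only; consult the model's javadoc
        // for the property names it actually reads.
        params.setProperty("numTrees", "500");
        params.setProperty("maxDepth", "20");
        params.setProperty("nodeSize", "5");
        params.setProperty("seed", "42");
        return params;
    }

    public static void main(String[] args) {
        Properties p = params();
        // A trainer reads typed values out of the bundle with defaults.
        int numTrees = Integer.parseInt(p.getProperty("numTrees", "100"));
        long seed = Long.parseLong(p.getProperty("seed", "0"));
        System.out.println(numTrees + " " + seed); // prints "500 42"
    }
}
```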