or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-analytics.md · classification.md · clustering.md · deep-learning.md · feature-engineering.md · index.md · regression.md · validation-metrics.md

docs/regression.md

0

# Regression

1

2

Supervised learning algorithms for predicting continuous values. Smile Core provides comprehensive regression capabilities from traditional linear models to advanced ensemble methods, kernel machines, and neural networks.

3

4

## Capabilities

5

6

### Core Regression Interface

7

8

All regression algorithms implement the unified `Regression<T>` interface, providing consistent prediction methods and optional online learning support.

9

10

```java { .api }

11

/**

12

* Base regression interface for all supervised learning algorithms

13

* @param <T> the type of input objects

14

*/

15

interface Regression<T> extends ToDoubleFunction<T>, Serializable {

16

/** Predict the target value for input */

17

double predict(T x);

18

19

/** Online learning update (if supported) */

20

default void update(T x, double y) { throw new UnsupportedOperationException(); }

21

22

/** Create ensemble of multiple regressors */

23

static <T> Regression<T> ensemble(Regression<T>... regressors);

24

}

25

```

26

27

### Linear Regression Models

28

29

Family of linear regression algorithms with various regularization techniques.

30

31

```java { .api }

32

/**

33

* Ordinary Least Squares regression

34

*/

35

class OLS implements Regression<double[]> {

36

/** Train OLS regression */

37

public static OLS fit(double[][] x, double[] y);

38

39

/** Train with intercept control */

40

public static OLS fit(double[][] x, double[] y, boolean intercept);

41

42

/** Predict target value */

43

public double predict(double[] x);

44

45

/** Get model coefficients */

46

public double[] coefficients();

47

48

/** Get intercept term */

49

public double intercept();

50

51

/** Get R-squared value */

52

public double RSquared();

53

54

/** Get adjusted R-squared */

55

public double adjustedRSquared();

56

57

/** Get residual sum of squares */

58

public double RSS();

59

60

/** Get total sum of squares */

61

public double TSS();

62

}

63

64

/**

65

* Ridge regression with L2 regularization

66

*/

67

class RidgeRegression implements Regression<double[]> {

68

/** Train ridge regression */

69

public static RidgeRegression fit(double[][] x, double[] y, double lambda);

70

71

/** Predict target value */

72

public double predict(double[] x);

73

74

/** Get model coefficients */

75

public double[] coefficients();

76

77

/** Get intercept term */

78

public double intercept();

79

80

/** Get regularization parameter */

81

public double lambda();

82

}

83

84

/**

85

* LASSO regression with L1 regularization

86

*/

87

class LASSO implements Regression<double[]> {

88

/** Train LASSO regression */

89

public static LASSO fit(double[][] x, double[] y, double lambda);

90

91

/** Train with tolerance and max iterations */

92

public static LASSO fit(double[][] x, double[] y, double lambda, double tolerance, int maxIter);

93

94

/** Predict target value */

95

public double predict(double[] x);

96

97

/** Get model coefficients */

98

public double[] coefficients();

99

100

/** Get intercept term */

101

public double intercept();

102

103

/** Get L1 penalty parameter */

104

public double lambda();

105

}

106

107

/**

108

* Elastic Net regression combining L1 and L2 regularization

109

*/

110

class ElasticNet implements Regression<double[]> {

111

/** Train Elastic Net regression */

112

public static ElasticNet fit(double[][] x, double[] y, double lambda1, double lambda2);

113

114

/** Train with convergence parameters */

115

public static ElasticNet fit(double[][] x, double[] y, double lambda1, double lambda2, double tolerance, int maxIter);

116

117

/** Predict target value */

118

public double predict(double[] x);

119

120

/** Get model coefficients */

121

public double[] coefficients();

122

123

/** Get L1 penalty parameter */

124

public double lambda1();

125

126

/** Get L2 penalty parameter */

127

public double lambda2();

128

}

129

```

130

131

**Usage Example:**

132

133

```java

134

import smile.regression.*;

135

136

// Train linear models

137

OLS ols = OLS.fit(x, y);

138

RidgeRegression ridge = RidgeRegression.fit(x, y, 0.1);

139

LASSO lasso = LASSO.fit(x, y, 0.01);

140

141

// Make predictions

142

double prediction = ols.predict(newSample);

143

double ridgePred = ridge.predict(newSample);

144

145

// Get model statistics

146

double r2 = ols.RSquared();

147

double[] coeffs = ridge.coefficients();

148

```

149

150

### Tree-Based Regression

151

152

Regression algorithms based on decision trees and ensemble methods.

153

154

```java { .api }

155

/**

156

* Regression tree using CART algorithm

157

*/

158

class RegressionTree implements Regression<double[]>, DataFrameRegression {

159

/** Train regression tree */

160

public static RegressionTree fit(double[][] x, double[] y);

161

162

/** Train with formula on DataFrame */

163

public static RegressionTree fit(Formula formula, DataFrame data);

164

165

/** Train with custom parameters */

166

public static RegressionTree fit(double[][] x, double[] y, int maxDepth, int maxNodes, int nodeSize);

167

168

/** Predict target value */

169

public double predict(double[] x);

170

171

/** Get feature importance */

172

public double[] importance();

173

174

/** Get tree structure */

175

public String toString();

176

}

177

178

/**

179

* Random Forest regression

180

*/

181

class RandomForest implements Regression<double[]>, DataFrameRegression {

182

/** Train random forest regression */

183

public static RandomForest fit(double[][] x, double[] y);

184

185

/** Train with formula on DataFrame */

186

public static RandomForest fit(Formula formula, DataFrame data);

187

188

/** Train with custom parameters */

189

public static RandomForest fit(double[][] x, double[] y, int numTrees, int mtry, int maxDepth, int nodeSize);

190

191

/** Predict target value */

192

public double predict(double[] x);

193

194

/** Get out-of-bag RMSE */

195

public double error();

196

197

/** Get feature importance */

198

public double[] importance();

199

}

200

201

/**

202

* Gradient Tree Boosting regression

203

*/

204

class GradientTreeBoost implements Regression<double[]>, DataFrameRegression {

205

/** Train gradient boosting regression */

206

public static GradientTreeBoost fit(double[][] x, double[] y);

207

208

/** Train with formula on DataFrame */

209

public static GradientTreeBoost fit(Formula formula, DataFrame data);

210

211

/** Train with custom parameters */

212

public static GradientTreeBoost fit(double[][] x, double[] y, int numTrees, int maxDepth, double shrinkage, double subsample);

213

214

/** Predict target value */

215

public double predict(double[] x);

216

217

/** Get feature importance */

218

public double[] importance();

219

}

220

```

221

222

### Kernel Methods

223

224

Kernel-based regression algorithms including support vector regression and Gaussian processes.

225

226

```java { .api }

227

/**

228

* Support Vector Regression

229

*/

230

class SVM implements Regression<double[]> {

231

/** Train SVR with RBF kernel */

232

public static SVM fit(double[][] x, double[] y, double gamma, double C, double epsilon);

233

234

/** Train SVR with custom kernel */

235

public static SVM fit(double[][] x, double[] y, Kernel kernel, double C, double epsilon);

236

237

/** Predict target value */

238

public double predict(double[] x);

239

240

/** Get support vectors */

241

public SupportVector[] supportVectors();

242

243

/** Get number of support vectors */

244

public int numSupportVectors();

245

}

246

247

/**

248

* Gaussian Process Regression

249

*/

250

class GaussianProcessRegression implements Regression<double[]> {

251

/** Train Gaussian Process with RBF kernel */

252

public static GaussianProcessRegression fit(double[][] x, double[] y, double sigma);

253

254

/** Train with custom kernel and noise */

255

public static GaussianProcessRegression fit(double[][] x, double[] y, Kernel kernel, double noise);

256

257

/** Predict target value */

258

public double predict(double[] x);

259

260

/** Predict with uncertainty estimate */

261

public double predict(double[] x, double[] variance);

262

263

/** Get posterior mean function */

264

public double[] mean();

265

266

/** Get kernel function */

267

public Kernel kernel();

268

}

269

```

270

271

### Neural Network Regression

272

273

Multi-layer perceptron for regression tasks with configurable architecture.

274

275

```java { .api }

276

/**

277

* Multi-Layer Perceptron regression

278

*/

279

class MLP implements Regression<double[]> {

280

/** Train MLP regression */

281

public static MLP fit(double[][] x, double[] y);

282

283

/** Train with custom architecture */

284

public static MLP fit(double[][] x, double[] y, int[] hiddenLayers, ActivationFunction activation);

285

286

/** Train with full configuration */

287

public static MLP fit(double[][] x, double[] y, Properties params);

288

289

/** Predict target value */

290

public double predict(double[] x);

291

292

/** Online learning update */

293

public void update(double[] x, double y);

294

295

/** Get network weights for layer */

296

public double[][] getWeights(int layer);

297

}

298

```

299

300

### Radial Basis Function Networks

301

302

RBF networks for non-linear regression with localized activation functions.

303

304

```java { .api }

305

/**

306

* Radial Basis Function Network regression

307

*/

308

class RBFNetwork implements Regression<double[]> {

309

/** Train RBF network with Gaussian RBFs */

310

public static RBFNetwork fit(double[][] x, double[] y, int numCenters);

311

312

/** Train with custom RBF and centers */

313

public static RBFNetwork fit(double[][] x, double[] y, RBF rbf, double[][] centers);

314

315

/** Predict target value */

316

public double predict(double[] x);

317

318

/** Get RBF centers */

319

public double[][] centers();

320

321

/** Get output weights */

322

public double[] weights();

323

}

324

```

325

326

### Base Classes and Utilities

327

328

Abstract base classes and utility functions for regression algorithms.

329

330

```java { .api }

331

/**

332

* Base class for linear regression models

333

*/

334

abstract class LinearModel implements Regression<double[]> {

335

/** Model coefficients */

336

public abstract double[] coefficients();

337

338

/** Intercept term */

339

public abstract double intercept();

340

341

/** Predict using linear combination */

342

public double predict(double[] x);

343

}

344

345

/**

346

* Base class for kernel-based regression

347

*/

348

abstract class KernelMachine<T> implements Regression<T> {

349

/** Kernel function */

350

public abstract Kernel<T> kernel();

351

352

/** Support vectors or training instances */

353

public abstract T[] instances();

354

355

/** Instance weights */

356

public abstract double[] weights();

357

}

358

359

/**

360

* Interface for DataFrame-based regression

361

*/

362

interface DataFrameRegression {

363

/** Train with formula on DataFrame */

364

static Regression<Tuple> fit(Formula formula, DataFrame data);

365

}

366

```

367

368

### Generalized Linear Models

369

370

GLM framework for regression with various distribution families.

371

372

```java { .api }

373

/**

374

* Generalized Linear Model

375

*/

376

class GLM implements Regression<double[]> {

377

/** Train GLM with Gaussian family (linear regression) */

378

public static GLM fit(double[][] x, double[] y);

379

380

/** Train GLM with specified family and link */

381

public static GLM fit(double[][] x, double[] y, GLM.Family family, GLM.Link link);

382

383

/** Train with regularization */

384

public static GLM fit(double[][] x, double[] y, GLM.Family family, GLM.Link link, double lambda, double alpha);

385

386

/** Predict target value */

387

public double predict(double[] x);

388

389

/** Get model coefficients */

390

public double[] coefficients();

391

392

/** Get deviance */

393

public double deviance();

394

395

/** GLM distribution families */

396

enum Family { GAUSSIAN, BINOMIAL, POISSON, GAMMA }

397

398

/** GLM link functions */

399

enum Link { IDENTITY, LOG, LOGIT, INVERSE, SQRT }

400

}

401

```

402

403

**Usage Example:**

404

405

```java

406

import smile.regression.GaussianProcessRegression;

407

import smile.regression.RandomForest;

408

409

// Gaussian Process with uncertainty

410

GaussianProcessRegression gp = GaussianProcessRegression.fit(x, y, 1.0);

411

double[] variance = new double[1];

412

double prediction = gp.predict(newSample, variance);

413

double uncertainty = Math.sqrt(variance[0]);

414

415

// Random Forest ensemble

416

RandomForest rf = RandomForest.fit(x, y, 500, 5, 20, 5);

417

double rfPrediction = rf.predict(newSample);

418

double oobError = rf.error();

419

```

420

421

### Training Patterns

422

423

All regression algorithms follow consistent training patterns:

424

425

**Array-based training:**

426

```java

427

Regression model = Algorithm.fit(double[][] x, double[] y);

428

Regression model = Algorithm.fit(double[][] x, double[] y, parameters...);

429

```

430

431

**DataFrame-based training:**

432

```java

433

Regression model = Algorithm.fit(Formula formula, DataFrame data);

434

```

435

436

**Prediction patterns:**

437

```java

438

double prediction = model.predict(double[] x);

439

double prediction = model.predict(double[] x, double[] uncertainty); // For probabilistic models

440

```

441

442

### Common Parameters

443

444

Most regression algorithms support these common configuration options:

445

446

- **lambda**: Regularization parameter (linear models)

447

- **alpha**: Elastic net mixing parameter

448

- **maxDepth**: Maximum tree depth (tree-based)

449

- **numTrees**: Number of trees in ensemble

450

- **nodeSize**: Minimum samples per leaf

451

- **shrinkage**: Learning rate (boosting)

452

- **subsample**: Fraction of samples for training

453

- **tolerance**: Convergence tolerance

454

- **maxIter**: Maximum iterations

455

- **seed**: Random seed for reproducibility