
# Core Training Infrastructure

Central training and learning infrastructure that forms the foundation of all fastai workflows. The Learner class coordinates model training, data handling, optimization, and callbacks.
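
To make that coordination concrete, here is a toy, dependency-free sketch of a callback-driven training loop. `ToyCallback`, `LossLogger` and `MiniLearner` are hypothetical names, not fastai classes. The real `Learner` trains PyTorch models from `DataLoaders` and fires more events (for example `after_pred`, `after_loss` and `before_step`). Only the event names used here are meant to echo fastai's; the model is a single weight fitted with plain SGD.

```python
# Toy sketch: ToyCallback, LossLogger and MiniLearner are hypothetical,
# not fastai classes; only the event names echo fastai's.

class ToyCallback:
    "Dispatches an event name to a method of the same name, if defined."
    def __call__(self, event, learn):
        handler = getattr(self, event, None)
        if handler is not None:
            handler(learn)


class LossLogger(ToyCallback):
    "Records the mean training loss of each epoch."
    def before_fit(self, learn):
        self.losses = []

    def before_epoch(self, learn):
        self.total, self.n = 0.0, 0

    def after_batch(self, learn):
        self.total += learn.loss
        self.n += 1

    def after_epoch(self, learn):
        self.losses.append(self.total / self.n)


class MiniLearner:
    "Fits y = w * x with per-sample SGD, firing callback events along the way."
    def __init__(self, data, lr=0.01, cbs=()):
        self.data, self.lr, self.cbs = data, lr, list(cbs)
        self.w = 0.0  # the single model parameter

    def _fire(self, event):
        for cb in self.cbs:
            cb(event, self)

    def fit(self, n_epoch):
        self._fire('before_fit')
        for epoch in range(n_epoch):
            self.epoch = epoch
            self._fire('before_epoch')
            for x, y in self.data:
                self._fire('before_batch')
                pred = self.w * x              # forward pass
                self.loss = (pred - y) ** 2    # squared-error loss
                grad = 2 * (pred - y) * x      # d(loss)/dw
                self.w -= self.lr * grad       # optimizer step
                self._fire('after_batch')
            self._fire('after_epoch')
        self._fire('after_fit')


data = [(x, 3.0 * x) for x in (1.0, 2.0, 3.0)]  # the true weight is 3
logger = LossLogger()
learn = MiniLearner(data, lr=0.02, cbs=[logger])
learn.fit(20)
print(round(learn.w, 3), logger.losses[0], logger.losses[-1])
```

Keeping the loop free of logging logic and pushing it into callbacks is the design idea: new behaviour is added by writing a callback rather than editing the loop.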

## Capabilities

### Main Learner Class

The central class for training models in fastai, managing the training loop, data, model, optimizer, and callbacks.

```python { .api }
class Learner:
    """
    Central class for training models.

    Parameters:
    - dls: DataLoaders with training and validation data
    - model: PyTorch model to train
    - loss_func: Loss function (inferred from dls if None)
    - opt_func: Optimizer constructor (default: Adam)
    - lr: Default learning rate (default: 0.001)
    - metrics: List of metrics to track during training
    - cbs: List of callbacks
    - wd: Default weight decay
    """
    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=0.001,
                 metrics=None, cbs=None, wd=None): ...

    def fit(self, n_epoch, lr=None, wd=None, cbs=None):
        """
        Train the model for n_epoch epochs.

        Parameters:
        - n_epoch: Number of epochs to train
        - lr: Learning rate (uses learner default if None)
        - wd: Weight decay (uses learner default if None)
        - cbs: Additional callbacks for this training run only
        """

    def fine_tune(self, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                  pct_start=0.3, div=5.0, **kwargs):
        """
        Fine-tune a pre-trained model: train the head with the body frozen
        for freeze_epochs, then unfreeze and train the whole model for
        epochs using discriminative learning rates.

        Parameters:
        - epochs: Number of epochs to train after unfreezing
        - base_lr: Base learning rate (halved for the unfrozen phase)
        - freeze_epochs: Epochs to train with the body frozen
        - lr_mult: Ratio of the head's learning rate to that of the earliest layers
        - pct_start: Fraction of the unfrozen phase spent warming up
        - div: Initial learning rate is the maximum learning rate divided by div
        """

    def predict(self, item, with_input=False):
        """
        Make a prediction on a single item.

        Parameters:
        - item: Input item to predict on
        - with_input: Whether to also return the decoded input

        Returns:
        - Decoded prediction, prediction index, raw outputs
          (preceded by the decoded input if with_input=True)
        """

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False,
                  act=None, inner=False, reorder=True, cbs=None):
        """
        Get predictions on a dataset.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid)
        - dl: DataLoader to use (uses learner's if None)
        - with_input: Include processed inputs
        - with_decoded: Include decoded predictions
        - act: Activation function to apply
        - inner: For internal use, when called from inside another training loop
        - reorder: Reorder predictions to match original order
        - cbs: Additional callbacks

        Returns:
        - (inputs), predictions, targets, (decoded); the parenthesized items
          are included only when with_input / with_decoded is set
        """

    def validate(self, ds_idx=1, dl=None, cbs=None):
        """
        Validate the model on a dataset.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid)
        - dl: DataLoader to use
        - cbs: Additional callbacks

        Returns:
        - Validation loss and metrics
        """

    def lr_find(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True,
                show_plot=True, suggest_funcs=(valley,)):
        """
        Find a good learning rate with a learning rate range test: train
        briefly while increasing the learning rate exponentially and record
        the loss at each step.

        Parameters:
        - start_lr: Starting learning rate
        - end_lr: Ending learning rate
        - num_it: Number of iterations
        - stop_div: Stop early if the loss diverges
        - show_plot: Display the loss vs. learning rate plot
        - suggest_funcs: Functions that suggest a learning rate from the recorded losses

        Returns:
        - SuggestedLRs named tuple with one suggestion per function
        """

    def freeze(self):
        """Freeze every parameter group except the last, i.e. the pre-trained body, leaving the head trainable."""

    def unfreeze(self):
        """Unfreeze the entire model for training."""

    def save(self, file, with_opt=True, pickle_protocol=2):
        """
        Save model weights (and optionally optimizer state).

        Parameters:
        - file: File name; saved as path/model_dir/file.pth
        - with_opt: Include optimizer state
        - pickle_protocol: Pickle protocol version
        """

    def load(self, file, with_opt=None, device=None, **kwargs):
        """
        Load model weights (and optionally optimizer state) written by save.

        Parameters:
        - file: File name, looked up as path/model_dir/file.pth
        - with_opt: Load optimizer state
        - device: Device to load to
        """

    def export(self, fname='export.pkl', pickle_protocol=2):
        """Pickle the learner to path/fname for inference, without optimizer state or data (reload with load_learner)."""
```
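
The learning-rate range test and the two phases of `fine_tune` can be sketched without any dependencies. Both helpers below, `lr_finder_schedule` and `fine_tune_plan`, are hypothetical and not part of fastai. The logic reflects our reading of fastai's source (exponential spacing for the range test; a frozen one-cycle phase with `pct_start=0.99`, then an unfrozen phase over `slice(base_lr/lr_mult, base_lr)`), so check it against your installed version.

```python
# Hypothetical helpers written from our reading of fastai's source;
# neither function exists in fastai.

def lr_finder_schedule(start_lr=1e-7, end_lr=10, num_it=100):
    "Exponentially spaced learning rates for an LR range test."
    # lr_i = start_lr * (end_lr / start_lr) ** (i / num_it), i = 0 .. num_it-1
    return [start_lr * (end_lr / start_lr) ** (i / num_it) for i in range(num_it)]


def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                   pct_start=0.3, div=5.0):
    "Describe the two one-cycle phases run by fine_tune."
    frozen = dict(phase='body frozen', epochs=freeze_epochs,
                  max_lr=base_lr,      # learning rate for the trainable head
                  pct_start=0.99)      # almost the whole phase is warm-up
    base_lr /= 2                       # unfrozen phase uses half the base rate
    unfrozen = dict(phase='unfrozen', epochs=epochs,
                    # (earliest layers, head), spread across parameter groups
                    lr_range=(base_lr / lr_mult, base_lr),
                    pct_start=pct_start, div=div)
    return [frozen, unfrozen]


lrs = lr_finder_schedule()
print(f"first lr {lrs[0]:.1e}, last lr {lrs[-1]:.2f}")
for phase in fine_tune_plan(epochs=3):
    print(phase)
```

With the defaults, the unfrozen phase gives the head 1e-3 and the earliest layers 1e-5, which is why `lr_mult` reads as a head-to-body ratio.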

### Model Management

Functions for loading and saving models and learners.

```python { .api }
def load_learner(path, cpu=True, pickle_module=pickle, map_location=None, **kwargs):
    """
    Load a saved learner from disk.

    Parameters:
    - path: Path to saved learner file
    - cpu: Load on CPU regardless of original device
    - pickle_module: Pickle module to use
    - map_location: Device mapping for loading

    Returns:
    - Loaded Learner instance
    """

def save_model(file, model, opt, with_opt=True, pickle_protocol=2):
    """
    Save model weights and optimizer state.

    Parameters:
    - file: Filename to save to
    - model: PyTorch model
    - opt: Optimizer
    - with_opt: Include optimizer state
    - pickle_protocol: Pickle protocol version
    """

def load_model(file, model, opt=None, with_opt=None, device=None, **kwargs):
    """
    Load model weights and optimizer state.

    Parameters:
    - file: Filename to load from
    - model: PyTorch model to load weights into
    - opt: Optimizer to load state into
    - with_opt: Load optimizer state
    - device: Device to load to
    """
```
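
To make the `with_opt` behaviour concrete, here is a stdlib-only sketch of the checkpoint layout we understand `save_model` and `load_model` to use: the bare model state when the optimizer is excluded, otherwise a dict with exactly the keys `'model'` and `'opt'`, which the loader detects. `save_checkpoint` and `load_checkpoint` are hypothetical names, plain dicts stand in for `state_dict()` results, and `pickle` stands in for `torch.save`/`torch.load`.

```python
import io
import pickle

# Hypothetical helpers: plain dicts stand in for state_dict() results and
# pickle stands in for torch.save / torch.load.

def save_checkpoint(file, model_state, opt_state=None, with_opt=True):
    "Write the bare model state, or {'model': ..., 'opt': ...} with the optimizer."
    if opt_state is None:
        with_opt = False                  # nothing to save for the optimizer
    state = {'model': model_state, 'opt': opt_state} if with_opt else model_state
    pickle.dump(state, file, protocol=2)


def load_checkpoint(file, with_opt=True):
    "Return (model_state, opt_state); opt_state is None if absent or not wanted."
    state = pickle.load(file)
    has_opt = isinstance(state, dict) and set(state) == {'model', 'opt'}
    model_state = state['model'] if has_opt else state
    opt_state = state['opt'] if (has_opt and with_opt) else None
    return model_state, opt_state


weights = {'layer.weight': [0.1, 0.2]}
full, bare = io.BytesIO(), io.BytesIO()   # in-memory "files"
save_checkpoint(full, weights, {'step': 10})
save_checkpoint(bare, weights, {'step': 10}, with_opt=False)
full.seek(0)
bare.seek(0)
full_loaded = load_checkpoint(full)       # model state plus optimizer state
bare_loaded = load_checkpoint(bare)       # model state only
print(full_loaded)
print(bare_loaded)
```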

### Tensor and Array Base Classes

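Tensor subclasses that tag PyTorch tensors with a semantic type (image, category, multi-label category, mask), so that fastai can choose how to display and transform each kind of data.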

```python { .api }
class TensorBase(Tensor):
    """Base class for fastai tensors with enhanced functionality."""

    def __new__(cls, x, **kwargs): ...
    def show(self, ctx=None, **kwargs): ...

class TensorImage(TensorBase):
    """Tensor subclass for image data."""

    def show(self, ctx=None, **kwargs): ...

class TensorCategory(TensorBase):
    """Tensor subclass for categorical data."""

    def show(self, ctx=None, **kwargs): ...

class TensorMultiCategory(TensorBase):
    """Tensor subclass for multi-label categorical data."""

    def show(self, ctx=None, **kwargs): ...

class TensorMask(TensorBase):
    """Tensor subclass for segmentation masks."""

    def show(self, ctx=None, **kwargs): ...
```

### Core Utilities

Utility functions for tensor creation and conversion, device placement, seeding, and one-hot encoding.

```python { .api }
def tensor(x, *rest, **kwargs):
    """
    Like torch.as_tensor, but also accepts lists and multiple elements
    passed directly (e.g. tensor(1, 2, 3)); float64 results become float32.

    Parameters:
    - x, *rest: Data to convert; extra positional arguments are combined with x
    - **kwargs: Passed on to torch, e.g. dtype or device

    Returns:
    - Torch tensor
    """

def to_device(b, device=None):
    """Move tensor(s) to device (the default device if None)."""

def to_cpu(b):
    """Move tensor(s) to CPU."""

def to_np(x):
    """Convert tensor(s) to numpy array(s)."""

def set_seed(s, reproducible=False):
    """
    Seed random, numpy and torch for reproducibility.

    Parameters:
    - s: Random seed value
    - reproducible: Also make cuDNN deterministic (can slow training)
    """

def one_hot(x, c):
    """Encode index or indices x as a length-c vector with a 1 at each index."""

def one_hot_decode(x, vocab=None):
    """Decode a one-hot vector to the indices (or vocab items) holding a 1."""
```
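
The encoding helpers are easy to misread as batch one-hot encoders; as we understand them, `one_hot` marks every index in `x` within a single length-`c` vector, and `one_hot_decode` reverses that. Below is a pure-Python sketch of that behaviour plus the reseeding idea behind `set_seed`. The `*_sketch` names are hypothetical, and fastai's real functions operate on tensors and seed more generators.

```python
import random

# Hypothetical pure-Python sketches; fastai's versions operate on tensors.

def one_hot_sketch(x, c):
    "Length-c list with a 1 at every index in x (several indices give multi-hot)."
    res = [0] * c
    for i in ([x] if isinstance(x, int) else x):
        res[i] = 1
    return res


def one_hot_decode_sketch(x, vocab=None):
    "Indices holding a 1, mapped through vocab when one is given."
    return [vocab[i] if vocab else i for i, v in enumerate(x) if v == 1]


def set_seed_sketch(s):
    "Seed Python's RNG only; fastai's set_seed also seeds numpy and torch."
    random.seed(s)


enc = one_hot_sketch([1, 3], 5)
dec = one_hot_decode_sketch(enc, vocab=['a', 'b', 'c', 'd', 'e'])
set_seed_sketch(42)
first = [random.random() for _ in range(3)]
set_seed_sketch(42)
second = [random.random() for _ in range(3)]
print(enc, dec, first == second)
```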