# Core Training Infrastructure
The training infrastructure that every fastai workflow builds on. The `Learner` class brings together the data (`DataLoaders`), the model, the loss function, the optimizer, and the callbacks that customize the training loop.
## Capabilities
### Main Learner Class
The central class for training models in fastai, managing the training loop, data, model, optimizer, and callbacks.
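To make "managing the training loop and callbacks" concrete, here is a heavily simplified, framework-free sketch of a callback-driven loop. It borrows fastai's event names (`before_fit`, `before_epoch`, `before_batch`, `after_pred`, `after_loss`, `after_batch`, `after_epoch`, `after_fit`), but it is our own illustration, not fastai's `Learner`: `MiniLearner`, `LossLogger`, and the one-parameter linear model are hypothetical, and the real loop also runs a validation pass, fires further events (such as `before_backward` and `before_step`), and supports cancellation.

```python
class Callback:
    """Base callback: any event a subclass does not define is a no-op."""
    def __call__(self, event_name, learner):
        handler = getattr(self, event_name, None)
        if handler is not None:
            handler(learner)


class LossLogger(Callback):
    """Hypothetical callback that records the mean training loss of each epoch."""
    def before_fit(self, learner):
        self.epoch_losses = []

    def before_epoch(self, learner):
        self.total, self.count = 0.0, 0

    def after_loss(self, learner):
        self.total += learner.loss
        self.count += 1

    def after_epoch(self, learner):
        self.epoch_losses.append(self.total / self.count)


class MiniLearner:
    """Hypothetical stand-in for Learner: fits y = w * x with plain SGD on squared error."""
    def __init__(self, data, lr=0.01, cbs=()):
        self.data, self.lr, self.cbs, self.w = data, lr, list(cbs), 0.0

    def _event(self, name):
        # Fire one event on every callback, in order.
        for cb in self.cbs:
            cb(name, self)

    def fit(self, n_epoch):
        self._event('before_fit')
        for self.epoch in range(n_epoch):
            self._event('before_epoch')
            for self.x, self.y in self.data:
                self._event('before_batch')
                self.pred = self.w * self.x                  # forward pass
                self._event('after_pred')
                self.loss = (self.pred - self.y) ** 2        # squared-error loss
                self._event('after_loss')
                grad = 2 * (self.pred - self.y) * self.x     # d(loss)/dw
                self.w -= self.lr * grad                     # optimizer step
                self._event('after_batch')
            self._event('after_epoch')
        self._event('after_fit')


data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]  # the true weight is 3
logger = LossLogger()
learn = MiniLearner(data, lr=0.05, cbs=[logger])
learn.fit(20)
print(learn.w, logger.epoch_losses[0], logger.epoch_losses[-1])
```

Because `fit` only announces events, behavior such as logging, scheduling, or early stopping can be added as callbacks without editing the loop itself.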
```python { .api }
# Names used as default arguments below
from fastai.optimizer import Adam
from fastai.callback.schedule import valley

class Learner:
    """
    Central class for training models.

    Parameters:
    - dls: DataLoaders with training and validation data
    - model: PyTorch model to train
    - loss_func: Loss function (inferred from the data if None)
    - opt_func: Optimizer constructor (default: Adam)
    - lr: Default learning rate (default: 0.001)
    - metrics: Metrics computed on the validation set after each epoch
    - cbs: List of callbacks
    - wd: Weight decay (the optimizer's default is used if None)
    """
    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=0.001,
                 metrics=None, cbs=None, wd=None): ...

    def fit(self, n_epoch, lr=None, wd=None, cbs=None):
        """
        Train the model for n_epoch epochs.

        Parameters:
        - n_epoch: Number of epochs to train
        - lr: Learning rate (uses the learner's default if None)
        - wd: Weight decay (uses the learner's default if None)
        - cbs: Extra callbacks used for this call only
        """

    def fine_tune(self, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                  pct_start=0.3, div=5.0, **kwargs):
        """
        Fine-tune a pretrained model: train with the body frozen, then unfreeze
        and train the whole model with discriminative learning rates.

        Parameters:
        - epochs: Number of epochs to train after unfreezing
        - base_lr: Base learning rate (halved for the unfrozen phase)
        - freeze_epochs: Number of epochs to train with the body frozen
        - lr_mult: Ratio between the head's learning rate and that of the earliest layers
        - pct_start: Fraction of the unfrozen one-cycle schedule spent increasing the learning rate
        - div: The unfrozen one-cycle schedule starts at its maximum learning rate divided by div
        - **kwargs: Passed on to fit_one_cycle
        """

    def predict(self, item, with_input=False):
        """
        Make a prediction on a single item.

        Parameters:
        - item: Input item to predict on
        - with_input: Also return the decoded input

        Returns:
        - Decoded prediction (e.g. the class label), decoded prediction tensor
          (e.g. the class index), and raw predictions (e.g. class probabilities);
          the decoded input comes first when with_input is True
        """

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False,
                  with_loss=False, act=None, inner=False, reorder=True, cbs=None):
        """
        Get predictions (and targets) for a whole dataset.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid); ignored if dl is given
        - dl: DataLoader to use instead of self.dls[ds_idx]
        - with_input: Also return the inputs
        - with_decoded: Also return the predictions decoded by the loss function
        - with_loss: Also return the per-item losses
        - act: Activation applied to the predictions (defaults to the loss function's activation)
        - inner: Set to True when called inside another training loop (no progress bar or logging)
        - reorder: Restore the dataset's original order if the DataLoader shuffled or sorted it
        - cbs: Extra callbacks used for this call only

        Returns:
        - Tuple of (inputs), predictions, targets, (decoded), (losses); items in
          parentheses appear only when the matching with_* flag is True
        """

    def validate(self, ds_idx=1, dl=None, cbs=None):
        """
        Evaluate the model on a dataset without training.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid)
        - dl: DataLoader to use instead of self.dls[ds_idx]
        - cbs: Extra callbacks used for this call only

        Returns:
        - List with the loss followed by each metric
        """

    def lr_find(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True,
                show_plot=True, suggest_funcs=(valley,)):
        """
        Run the learning rate range test: train briefly while increasing the
        learning rate exponentially from start_lr to end_lr, record the loss,
        and suggest a learning rate. The model's weights are restored afterwards.

        Parameters:
        - start_lr: Starting learning rate
        - end_lr: Final learning rate
        - num_it: Number of mini-batches (iterations) to run
        - stop_div: Stop early once the loss diverges
        - show_plot: Plot loss against learning rate
        - suggest_funcs: Functions that pick suggested learning rates from the recorded losses

        Returns:
        - SuggestedLRs named tuple with one field per suggestion function
        """

    def freeze(self):
        """Freeze all parameter groups except the last (typically the head)."""

    def unfreeze(self):
        """Unfreeze all parameter groups so the whole model is trained."""

    def save(self, file, with_opt=True, pickle_protocol=2):
        """
        Save the model weights (and optimizer state) to self.path/self.model_dir/file.pth.

        Parameters:
        - file: File name, without the .pth extension
        - with_opt: Include optimizer state
        - pickle_protocol: Pickle protocol version
        """

    def load(self, file, with_opt=None, device=None, **kwargs):
        """
        Load weights (and optimizer state) previously written by save().

        Parameters:
        - file: File name in self.path/self.model_dir, without the .pth extension
        - with_opt: Also load optimizer state, if the file contains it
        - device: Device to load onto (defaults to the DataLoaders' device)
        """

    def export(self, fname='export.pkl', pickle_protocol=2):
        """Pickle the learner to self.path/fname for inference with load_learner, without training data or optimizer state."""
```
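The interplay of `base_lr`, `freeze_epochs`, and `lr_mult` in `fine_tune` is easiest to see as arithmetic. The sketch below reflects our reading of the fastai source, which may differ between versions: `fine_tune` calls `freeze()` and runs `fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99)`, then halves `base_lr`, calls `unfreeze()`, and runs `fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div=div)`. A `slice(lo, hi)` is spread geometrically across parameter groups (as fastai's `even_mults` does), while `slice(hi)` gives the last group `hi` and earlier groups `hi/10`. `fine_tune_plan` is a hypothetical helper that only computes these numbers, and three parameter groups is an assumed split.

```python
def even_mults(start, stop, n):
    """Geometric spacing from start to stop over n values (mirrors fastai's even_mults)."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


def lrs_for_slice(lo, hi, n_groups):
    """Per-group learning rates for slice(lo, hi), or for slice(hi) when lo is None."""
    if lo is None:
        # slice(hi): earlier groups get hi / 10, the last group (the head) gets hi
        return [hi / 10] * (n_groups - 1) + [hi]
    return even_mults(lo, hi, n_groups)


def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                   pct_start=0.3, div=5.0, n_groups=3):
    """Hypothetical helper: describe the two one-cycle phases that fine_tune runs."""
    frozen = dict(epochs=freeze_epochs, pct_start=0.99, trainable='last group only',
                  lrs=lrs_for_slice(None, base_lr, n_groups))
    base_lr /= 2
    unfrozen = dict(epochs=epochs, pct_start=pct_start, div=div, trainable='all groups',
                    lrs=lrs_for_slice(base_lr / lr_mult, base_lr, n_groups))
    return frozen, unfrozen


frozen, unfrozen = fine_tune_plan(epochs=4)
print(frozen['lrs'])    # the head trains at 2e-3 while the body is frozen
print(unfrozen['lrs'])  # approximately 1e-5, 1e-4, 1e-3 across the three groups
```

With the defaults, the earliest layers of the pretrained body move 100 times more slowly than the head once everything is unfrozen.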
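`lr_find` runs the learning rate range test: for `num_it` mini-batches it raises the learning rate exponentially, `lr_i = start_lr * (end_lr / start_lr) ** (i / num_it)`, records the loss, and with `stop_div=True` stops once the loss diverges. The sketch below computes that schedule and runs a toy range test on `loss(w) = w**2`. As we understand fastai's `LRFinder`, divergence means a (smoothed) loss above four times the best loss seen so far; this sketch uses the raw loss, and the threshold should be checked against your fastai version. `lr_schedule` and `run_range_test` are hypothetical names.

```python
import math


def lr_schedule(start_lr=1e-7, end_lr=10, num_it=100):
    """Exponential interpolation: lr_i = start_lr * (end_lr / start_lr) ** (i / num_it)."""
    return [start_lr * (end_lr / start_lr) ** (i / num_it) for i in range(num_it)]


def run_range_test(lrs, stop_div=True):
    """Toy range test: one SGD step on loss(w) = w**2 per learning rate, recording (lr, loss)."""
    w, best, history = 1.0, math.inf, []
    for lr in lrs:
        w -= lr * 2 * w                    # gradient of w**2 is 2w
        loss = w * w
        history.append((lr, loss))
        best = min(best, loss)
        if stop_div and loss > 4 * best:   # loss has diverged: stop early
            break
    return history


lrs = lr_schedule()
history = run_range_test(lrs)
print(f"stopped after {len(history)} of {len(lrs)} steps at lr={history[-1][0]:.2f}")
```

On this toy problem the loss falls steadily until the learning rate passes 1, where each step starts overshooting, which is the kind of knee the suggestion functions look for.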
### Model Management
Functions for saving and loading model weights (`save_model`, `load_model`) and exported learners (`load_learner`).
```python { .api }
import pickle

def load_learner(path, cpu=True, pickle_module=pickle, map_location=None, **kwargs):
    """
    Load a learner saved with Learner.export.

    Parameters:
    - path: Path to the exported file
    - cpu: Load onto the CPU regardless of the device used for training
    - pickle_module: Pickle module to use
    - map_location: Device mapping for loading

    Returns:
    - Loaded Learner instance
    """

def save_model(file, model, opt, with_opt=True, pickle_protocol=2):
    """
    Save model weights and, optionally, optimizer state.

    Parameters:
    - file: Path of the file to write
    - model: PyTorch model
    - opt: Optimizer (if None, only the weights are saved)
    - with_opt: Include optimizer state
    - pickle_protocol: Pickle protocol version
    """

def load_model(file, model, opt=None, with_opt=None, device=None, **kwargs):
    """
    Load model weights and, optionally, optimizer state.

    Parameters:
    - file: Path of the file to read
    - model: PyTorch model to load weights into
    - opt: Optimizer to load state into
    - with_opt: Also load optimizer state, if the file contains it
    - device: Device to map the loaded tensors to
    """
```
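`save_model` and `load_model` differ only in whether optimizer state travels with the weights. As we understand fastai's implementation, `save_model` writes the model's `state_dict()` alone when `with_opt=False` (or `opt` is None), and a dict `{'model': ..., 'opt': ...}` otherwise; `load_model` tells the layouts apart by checking whether the saved keys are exactly `{'model', 'opt'}`, and warns if optimizer state was requested but is missing. The sketch reproduces that logic with plain dictionaries and `pickle` instead of `torch.save`/`torch.load`, so it runs without PyTorch; `save_checkpoint` and `load_checkpoint` are hypothetical names.

```python
import os
import pickle
import tempfile
import warnings


def save_checkpoint(file, model_state, opt_state=None, with_opt=True, pickle_protocol=2):
    """Write the weights alone, or {'model', 'opt'} when with_opt is set and opt_state exists."""
    state = model_state
    if with_opt and opt_state is not None:
        state = {'model': model_state, 'opt': opt_state}
    with open(file, 'wb') as f:
        pickle.dump(state, f, protocol=pickle_protocol)


def load_checkpoint(file, with_opt=True):
    """Return (model_state, opt_state or None), detecting which layout was saved."""
    with open(file, 'rb') as f:
        state = pickle.load(f)
    has_opt = set(state) == {'model', 'opt'}
    model_state = state['model'] if has_opt else state
    if with_opt and not has_opt:
        warnings.warn("Saved file doesn't contain an optimizer state.")
    return model_state, (state['opt'] if has_opt and with_opt else None)


weights = {'linear.weight': [[0.1, 0.2]], 'linear.bias': [0.0]}
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'model.pth')
    save_checkpoint(path, weights, {'step': 10})
    print(load_checkpoint(path))
```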
### Tensor and Array Base Classes
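Tensor subclasses that tag data with its semantic type (image, category, multi-label category, segmentation mask), so that fastai can choose the appropriate transforms and `show` behavior. `TensorBase` keeps the subclass type through PyTorch operations.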
```python { .api }
from torch import Tensor

class TensorBase(Tensor):
    """Base class for fastai tensors; the subclass type is kept through PyTorch operations."""

    def __new__(cls, x, **kwargs): ...
    def show(self, ctx=None, **kwargs): ...

class TensorImage(TensorBase):
    """Tensor subclass for image data."""

    def show(self, ctx=None, **kwargs): ...

class TensorCategory(TensorBase):
    """Tensor subclass for categorical data."""

    def show(self, ctx=None, **kwargs): ...

class TensorMultiCategory(TensorBase):
    """Tensor subclass for multi-label categorical data."""

    def show(self, ctx=None, **kwargs): ...

class TensorMask(TensorBase):
    """Tensor subclass for segmentation masks."""

    def show(self, ctx=None, **kwargs): ...
```
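These subclasses let behavior depend on what the data means rather than on its contents: an image and a mask may both be integer arrays, yet they should be displayed differently. fastai implements this with its own type-dispatch machinery on top of these tensor types. The sketch below shows the same idea as an analogy using the standard library's `functools.singledispatch` and plain-list stand-ins; it is not fastai's implementation, and `describe` and the `Fake*` classes are hypothetical.

```python
from functools import singledispatch


class FakeTensorBase(list):
    """Stand-in for TensorBase: a plain list tagged with a semantic type."""

class FakeTensorImage(FakeTensorBase): pass
class FakeTensorMask(FakeTensorBase): pass
class FakeTensorCategory(FakeTensorBase): pass


@singledispatch
def describe(x):
    # Fallback for untyped data
    return f"raw data with {len(x)} values"

@describe.register
def _(x: FakeTensorImage):
    return f"image with {len(x)} pixels"

@describe.register
def _(x: FakeTensorMask):
    return f"mask with classes {sorted(set(x))}"

@describe.register
def _(x: FakeTensorCategory):
    return f"category index {x[0]}"


pixels = [0, 1, 1, 0]
for item in (FakeTensorImage(pixels), FakeTensorMask(pixels), FakeTensorCategory([2]), pixels):
    print(describe(item))
```

Supporting a new data type then means registering one more implementation rather than extending a chain of `isinstance` checks.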
### Core Utilities
Utility functions for creating tensors, moving data between devices, converting to NumPy, seeding random number generators, and one-hot encoding.
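Several of these utilities, such as `to_device`, `to_cpu`, and `to_np`, accept either a single tensor or a nested batch. As we understand it, fastai handles this with a recursive helper that rebuilds lists and tuples with their original type, maps over dict values, and applies the conversion only at the leaves. A minimal pure-Python sketch of that traversal follows; `apply_nested` is a hypothetical name, and `float` stands in for the per-tensor device move or NumPy conversion.

```python
def apply_nested(func, x):
    """Apply func to every leaf of nested lists, tuples, and dicts, keeping container types."""
    if isinstance(x, (list, tuple)):
        # Note: namedtuples would need type(x)(*items); this sketch does not handle them
        return type(x)(apply_nested(func, o) for o in x)
    if isinstance(x, dict):
        return {k: apply_nested(func, v) for k, v in x.items()}
    return func(x)


batch = ({'image': [1, 2], 'mask': (0, 1)}, 3)
print(apply_nested(float, batch))
```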
```python { .api }
def tensor(x, *rest, **kwargs):
    """
    Like torch.as_tensor, but also accepts lists and several elements passed
    directly, so tensor(1, 2, 3) works like tensor([1, 2, 3]).
    float64 results are converted to float32.

    Parameters:
    - x: Data to convert (tensor, array, list, or scalar)
    - *rest: Further elements, combined with x into one tensor
    - **kwargs: Passed on to PyTorch (e.g. dtype)

    Returns:
    - Torch tensor
    """

def to_device(b, device=None):
    """Move a tensor, or every tensor in a nested collection, to device (fastai's default device if None)."""

def to_cpu(b):
    """Move a tensor, or every tensor in a nested collection, to the CPU."""

def to_np(x):
    """Convert a tensor, or every tensor in a nested collection, to a NumPy array."""

def set_seed(s, reproducible=False):
    """
    Seed Python's random module, NumPy, and PyTorch for reproducibility.

    Parameters:
    - s: Random seed value
    - reproducible: Also make cuDNN deterministic (cudnn.deterministic=True, cudnn.benchmark=False)
    """

def one_hot(x, c):
    """Return a length-c vector with 1 at each index in x (x may be one index or several)."""

def one_hot_decode(x, vocab=None):
    """Return the positions set to 1 in x, or the matching vocab items if vocab is given."""
```
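Despite its name, `one_hot` behaves like a multi-hot encoder: given one index or a list of indices and the number of classes `c`, it returns a length-`c` vector with a 1 at every listed index, and `one_hot_decode` maps such a vector back to the positions (or, with `vocab`, the vocabulary items) that are set. That is our reading of the fastai functions, which return tensors; the sketch below reproduces the behavior with plain lists, and `multi_hot`/`multi_hot_decode` are hypothetical names.

```python
def multi_hot(x, c):
    """Length-c list with 1 at each index in x (x may be a single int or an iterable of ints)."""
    indices = [x] if isinstance(x, int) else list(x)
    res = [0] * c
    for i in indices:
        res[i] = 1
    return res


def multi_hot_decode(x, vocab=None):
    """Positions set to 1, mapped through vocab when one is given."""
    return [vocab[i] if vocab else i for i, v in enumerate(x) if v == 1]


vocab = ['cat', 'dog', 'bird', 'fish']
enc = multi_hot([1, 3], c=len(vocab))
print(enc)                           # [0, 1, 0, 1]
print(multi_hot_decode(enc, vocab))  # ['dog', 'fish']
```

This is the natural target format for multi-label classification, where one item can carry several labels at once.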