fastai simplifies training fast and accurate neural nets using modern best practices
—
Central training and learning infrastructure that forms the foundation of all fastai workflows. The Learner class coordinates model training, data handling, optimization, and callbacks.
The central class for training models in fastai, managing the training loop, data, model, optimizer, and callbacks.
class Learner:
    """
    Central class for training models.

    Parameters:
    - dls: DataLoaders with training and validation data
    - model: PyTorch model to train
    - loss_func: Loss function (auto-inferred if None)
    - opt_func: Optimizer constructor (default: Adam)
    - lr: Default learning rate (default: 0.001)
    - metrics: List of metrics to track during training (default: None)
    - cbs: List of callbacks (default: None)
    - wd: Default weight decay (default: None)
    """

    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=0.001,
                 metrics=None, cbs=None, wd=None): ...

    def fit(self, n_epoch, lr=None, wd=None, cbs=None):
        """
        Train the model for `n_epoch` epochs.

        Parameters:
        - n_epoch: Number of epochs to train
        - lr: Learning rate for this run (the learner's default if None)
        - wd: Weight decay for this run (the learner's default if None)
        - cbs: Additional callbacks used only for this training run
        """

    def fine_tune(self, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
                  pct_start=0.3, div=5.0, **kwargs):
        """
        Fine-tune a pre-trained model.

        Parameters:
        - epochs: Number of fine-tuning epochs
        - base_lr: Base learning rate for fine-tuning (default: 2e-3)
        - freeze_epochs: Number of epochs to train with the body frozen
          (default: 1)
        - lr_mult: Learning-rate multiplier between head and body
          (default: 100)
        - pct_start: Fraction (not a percentage) of training spent warming
          up; the default 0.3 means 30%
        - div: Learning-rate division factor (default: 5.0)
          NOTE(review): presumably the schedule starts at max_lr / div —
          confirm
        - **kwargs: Extra options; NOTE(review): where they are forwarded
          is not stated here — confirm
        """

    def predict(self, item, with_input=False):
        """
        Make a prediction on a single item.

        Parameters:
        - item: Input item to predict on
        - with_input: If True, also return the processed input
          (default: False)

        Returns:
        - Prediction class, prediction index and raw outputs; when
          `with_input` is True the processed input is returned as well.
          NOTE(review): its position in the result is not stated here —
          confirm
        """

    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=True,
                  act=None, inner=False, reorder=True, cbs=None):
        """
        Get predictions on a dataset.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid); used when `dl` is None
        - dl: DataLoader to use (uses the learner's if None)
        - with_input: Include processed inputs (default: False)
        - with_decoded: Include decoded predictions (default: True)
          NOTE(review): upstream fastai defaults this to False — verify
          this signature
        - act: Activation function to apply to the outputs (default: None)
        - inner: Default False. NOTE(review): in upstream fastai this flag
          marks a call made inside another training loop rather than
          selecting "inner model outputs" — verify
        - reorder: Reorder predictions to match the original order
          (default: True)
        - cbs: Additional callbacks (default: None)

        Returns:
        - Predictions and targets, plus the inputs (if `with_input`) and the
          decoded predictions (if `with_decoded`)
        """

    def validate(self, ds_idx=1, dl=None, cbs=None):
        """
        Validate the model on a dataset.

        Parameters:
        - ds_idx: Dataset index (0=train, 1=valid); used when `dl` is None
        - dl: DataLoader to use (default: None)
        - cbs: Additional callbacks (default: None)

        Returns:
        - Validation loss and metrics
        """

    def lr_find(self, start_lr=1e-7, end_lr=10, num_it=100, step_mode='exp',
                show_plot=True, suggest_funcs=(valley, slide)):
        """
        Find a good learning rate using a learning-rate range test.

        Parameters:
        - start_lr: Starting learning rate (default: 1e-7)
        - end_lr: Ending learning rate (default: 10)
        - num_it: Number of iterations (default: 100)
        - step_mode: 'exp' or 'linear' stepping between the two rates
        - show_plot: Display the learning-rate plot (default: True)
        - suggest_funcs: Functions that each suggest a learning rate
          (default: valley and slide)

        Returns:
        - SuggestedLRs object with recommendations
        """

    def freeze(self):
        """Freeze the model body (typically pre-trained layers)."""

    def unfreeze(self):
        """Unfreeze the entire model for training."""

    def save(self, file, with_opt=True, pickle_protocol=2):
        """
        Save learner state.

        Parameters:
        - file: Filename to save to
        - with_opt: Include optimizer state (default: True)
        - pickle_protocol: Pickle protocol version (default: 2)
        """

    def load(self, file, with_opt=None, device=None, **kwargs):
        """
        Load learner state.

        Parameters:
        - file: Filename to load from
        - with_opt: Load optimizer state (default: None)
        - device: Device to load to (default: None)
        - **kwargs: Extra loading options (not described by this stub)
        """

    def export(self, file='export.pkl', pickle_protocol=2):
        """
        Export the learner for inference (without training state).

        Parameters:
        - file: Output filename (default: 'export.pkl')
        - pickle_protocol: Pickle protocol version (default: 2)
        """


# Functions for loading and saving models and learners.
def load_learner(path, cpu=True, pickle_module=pickle, map_location=None, **kwargs):
    """
    Load a saved learner from disk.

    WARNING: the file is deserialized with `pickle_module`, and unpickling
    can execute arbitrary code. Only load files from sources you trust.

    Parameters:
    - path: Path to the saved learner file
    - cpu: Load on CPU regardless of the original device (default: True)
    - pickle_module: Pickle module used to deserialize the file
      (default: pickle)
    - map_location: Device mapping for loading (default: None)
    - **kwargs: Extra loading options (not described by this stub)

    Returns:
    - Loaded Learner instance
    """
def save_model(file, model, opt, with_opt=True, pickle_protocol=2):
    """
    Write a model's weights, and optionally its optimizer's state, to `file`.

    Parameters:
    - file: Destination filename
    - model: PyTorch model whose weights are written
    - opt: Optimizer whose state may be written
    - with_opt: Whether to include the optimizer state (default: True)
    - pickle_protocol: Pickle protocol version used when writing (default: 2)
    """
def load_model(file, model, opt=None, with_opt=None, device=None, **kwargs):
    """
    Read model weights, and optionally optimizer state, from `file`.

    Parameters:
    - file: Source filename
    - model: PyTorch model that receives the weights
    - opt: Optimizer that receives the saved state (default: None)
    - with_opt: Whether to load the optimizer state (default: None)
    - device: Device to load onto (default: None)
    - **kwargs: Extra options (not described by this stub)
    """


# Core tensor classes that extend PyTorch tensors with fastai functionality.
class TensorBase(Tensor):
    """Root of the fastai tensor subclasses; adds extra functionality on top of `Tensor`."""

    # Construct from data `x`; extra options arrive through **kwargs.
    def __new__(cls, x, **kwargs): ...

    # Display this tensor; NOTE(review): `ctx` presumably an existing display context — confirm.
    def show(self, ctx=None, **kwargs): ...
class TensorImage(TensorBase):
    """Tensor subclass that holds image data."""

    # Display the image; extra display options arrive through **kwargs.
    def show(self, ctx=None, **kwargs): ...
class TensorCategory(TensorBase):
    """Tensor subclass that holds categorical data."""

    # Display the category; extra display options arrive through **kwargs.
    def show(self, ctx=None, **kwargs): ...
class TensorMultiCategory(TensorBase):
    """Tensor subclass that holds multi-label categorical data."""

    # Display the labels; extra display options arrive through **kwargs.
    def show(self, ctx=None, **kwargs): ...
class TensorMask(TensorBase):
    """Tensor subclass that holds segmentation masks."""

    # Display the mask; extra display options arrive through **kwargs.
    def show(self, ctx=None, **kwargs): ...


# Essential utility functions for tensor operations and device management.
def tensor(x, *rest, **kwargs):
    """
    Enhanced tensor creation with automatic device handling.

    Parameters:
    - x: Data to convert to a tensor
    - *rest: Additional positional values (NOTE(review): presumably
      combined with `x` into a single tensor — confirm)
    - **kwargs: Keyword options. `dtype` and `device` are not named
      parameters of this signature; pass them as keywords:
        - dtype: Data type of the result
        - device: Device to place the tensor on

    Returns:
    - Torch tensor
    """
def to_device(b, device=None):
    """
    Move one or more tensors onto a device.

    Parameters:
    - b: Tensor(s) to move
    - device: Target device (default: None)
    """
def to_cpu(b):
    """
    Move one or more tensors onto the CPU.

    Parameters:
    - b: Tensor(s) to move
    """
def to_np(x):
    """
    Convert the tensor `x` into a NumPy array.

    Parameters:
    - x: Tensor to convert
    """
def set_seed(s, reproducible=False):
    """
    Seed random number generation so that runs can be reproduced.

    Parameters:
    - s: Seed value
    - reproducible: If True, also enable deterministic algorithms
      (default: False)
    """
def one_hot(x, c):
    """
    Encode `x` as a one-hot representation.

    Parameters:
    - x: Value(s) to encode
    - c: NOTE(review): presumably the number of classes (length of the
      encoding) — confirm
    """
def one_hot_decode(x, vocab=None):
    """
    Reverse a one-hot encoding.

    Parameters:
    - x: One-hot encoded input
    - vocab: Optional vocabulary (default: None); NOTE(review): presumably
      maps decoded positions to labels — confirm
    """


# Install with Tessl CLI
npx tessl i tessl/pypi-fastai