# fastai simplifies training fast and accurate neural nets using modern best practices.
#
# Complete computer vision toolkit for image classification, segmentation,
# object detection, and more. Built on top of the core fastai infrastructure
# with domain-specific optimizations.
#
# Main entry points for creating vision models with pre-trained architectures
# and domain-specific optimizations.
def vision_learner(dls, arch, normalize=True, n_out=None, pretrained=True,
                   cut=None, splitter=None, y_range=None, config=None,
                   loss_func=None, opt_func=Adam, lr=defaults.lr, metrics=None,
                   **kwargs):
    """Create a `Learner` configured for computer-vision tasks.

    Parameters:
    - dls: DataLoaders with image data.
    - arch: Model architecture callable (resnet34, efficientnet_b0, etc.).
    - normalize: Apply ImageNet normalization to the inputs.
    - n_out: Number of outputs; auto-detected from the data when None.
    - pretrained: Use pre-trained weights.
    - cut: Where to cut the pre-trained model into body/head.
    - splitter: Function to split the model into parameter groups for
      differential learning rates.
    - y_range: Range of target values for regression heads.
    - config: Model configuration.
    - loss_func: Loss function; auto-selected from the data when None.
    - opt_func: Optimizer constructor (default Adam).
    - lr: Learning rate.
    - metrics: Metrics to track during training.
    - **kwargs: Forwarded to the underlying Learner constructor.

    Returns:
    - Learner instance configured for vision tasks.
    """
def cnn_learner(dls, arch, **kwargs):
    """Deprecated alias for `vision_learner`; use `vision_learner` instead.

    All arguments are passed through unchanged.
    """
def unet_learner(dls, arch, normalize=True, n_out=None, img_size=None,
                 pretrained=True, cut=None, splitter=None, y_range=None,
                 config=None, loss_func=None, opt_func=Adam, lr=defaults.lr,
                 metrics=None, **kwargs):
    """Create a U-Net `Learner` for segmentation tasks.

    Parameters:
    - dls: DataLoaders with image and mask data.
    - arch: Encoder architecture (resnet34, etc.).
    - normalize: Apply ImageNet normalization.
    - n_out: Number of output classes.
    - img_size: Input image size.
    - pretrained: Use a pre-trained encoder.
    - cut: Where to cut the encoder.
    - splitter: Function to split model layers into parameter groups
      (for differential learning rates).
    - y_range: Range for regression outputs.
    - config: Model configuration.
    - loss_func: Loss function (typically CrossEntropyLoss); presumably
      auto-selected from the data when None, as in `vision_learner`.
    - opt_func: Optimizer constructor (default Adam).
    - lr: Learning rate.
    - metrics: Metrics to track (Dice, IoU, etc.).

    Returns:
    - Learner instance configured for segmentation.
    """
def create_vision_model(arch, n_out=1000, pretrained=True, cut=None, **kwargs):
    """Create a vision model without the `Learner` wrapper.

    - arch: architecture callable.
    - n_out: number of outputs of the final layer (default 1000, ImageNet classes).
    - pretrained: load pre-trained weights.
    - cut: where to cut the pre-trained model.
    """
def create_unet_model(arch, n_out, img_size=None, pretrained=True, cut=None, **kwargs):
    """Create a U-Net model without the `Learner` wrapper.

    - arch: encoder architecture callable.
    - n_out: number of output channels/classes (required).
    - img_size: input image size.
    - pretrained: load pre-trained encoder weights.
    - cut: where to cut the encoder.
    """

# Specialized DataLoaders for common computer vision tasks.
class ImageDataLoaders(DataLoaders):
    """DataLoaders for image datasets, with convenience alternate constructors."""

    @classmethod
    def from_folder(cls, path, train='train', valid='valid', valid_pct=None,
                    seed=None, vocab=None, item_tfms=None, batch_tfms=None,
                    img_cls=PILImage, **kwargs):
        """Create ImageDataLoaders from an ImageNet-style folder structure.

        Parameters:
        - path: Path to the data directory.
        - train: Training subfolder name.
        - valid: Validation subfolder name.
        - valid_pct: Random validation split percentage (used when there is
          no dedicated validation folder).
        - seed: Random seed for the split.
        - vocab: Category vocabulary.
        - item_tfms: Item-level transforms.
        - batch_tfms: Batch-level transforms.
        - img_cls: Image class used to open files.

        Returns:
        - ImageDataLoaders instance.
        """

    @classmethod
    def from_name_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None,
                       item_tfms=None, batch_tfms=None, **kwargs):
        """Create ImageDataLoaders by labelling each file with `label_func`.

        Parameters:
        - path: Path to the images.
        - fnames: List of filenames.
        - label_func: Function mapping a filename to its label.
        - valid_pct: Validation split percentage.
        - seed: Random seed for the split.
        - item_tfms: Item-level transforms.
        - batch_tfms: Batch-level transforms.

        Returns:
        - ImageDataLoaders instance.
        """

    @classmethod
    def from_name_re(cls, path, fnames, pat, valid_pct=0.2, **kwargs):
        """Create ImageDataLoaders using regex pattern `pat` to extract labels from filenames."""

    @classmethod
    def from_path_func(cls, path, fnames, label_func, valid_pct=0.2, **kwargs):
        """Create ImageDataLoaders using a labelling function applied to full paths."""

    @classmethod
    def from_path_re(cls, path, fnames, pat, valid_pct=0.2, **kwargs):
        """Create ImageDataLoaders using regex pattern `pat` applied to full paths."""

    @classmethod
    def from_lists(cls, path, fnames, labels, valid_pct=0.2, **kwargs):
        """Create ImageDataLoaders from parallel lists of filenames and labels."""

    @classmethod
    def from_csv(cls, path, csv_fname, header='infer', delimiter=None, **kwargs):
        """Create ImageDataLoaders from a CSV file of filenames and labels."""

    @classmethod
    def from_df(cls, df, path='.', valid_pct=0.2, **kwargs):
        """Create ImageDataLoaders from a pandas DataFrame."""

# Core image classes for handling different image types.
class PILImage(PILBase):
    """PIL Image wrapper with fastai functionality (display, thumbnails, creation)."""

    @classmethod
    def create(cls, fn:(Path,str,Tensor,ndarray,bytes), **kwargs):
        """Create a PILImage from a path, string, tensor, array, or raw bytes.

        The tuple annotation is fastai's type-dispatch notation, not a
        standard typing union.
        """

    def show(self, ctx=None, figsize=None, title=None, **kwargs):
        """Display the image on `ctx` (a matplotlib axes, presumably -- confirm)."""

    def to_thumb(self, h, w=None):
        """Create a thumbnail of height `h` (and width `w`, defaulting to `h` when None -- assumed)."""

    @property
    def shape(self):
        """Image shape as (height, width, channels)."""
class PILImageBW(PILImage):
    """PIL Image wrapper for grayscale images."""

    # Display default: render single-channel data with the grayscale colormap.
    _show_args = {'cmap': 'gray'}
class PILMask(PILImage):
    """PIL Image wrapper for segmentation masks."""

    # Display defaults: categorical colormap, semi-transparent overlay,
    # class indices mapped onto the 0-20 colour range.
    _show_args = {'cmap': 'tab20', 'alpha': 0.5, 'vmin': 0, 'vmax': 20}

    def show(self, ctx=None, figsize=None, title=None, **kwargs):
        """Display the mask using the categorical colour mapping in `_show_args`."""

# Comprehensive augmentation pipeline for robust model training.
def aug_transforms(mult=1.0, do_flip=True, flip_vert=False, max_rotate=10.0,
                   min_zoom=1.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2,
                   p_affine=0.75, p_lighting=0.75, xtra_tfms=None, size=None,
                   mode='bilinear', pad_mode='reflection', align_corners=True,
                   batch=False, min_scale=1.0):
    """Build the standard set of data-augmentation transforms.

    Parameters:
    - mult: Global multiplier applied to the augmentation strength.
    - do_flip: Enable horizontal flips.
    - flip_vert: Enable vertical flips as well.
    - max_rotate: Maximum rotation in degrees.
    - min_zoom: Minimum zoom factor.
    - max_zoom: Maximum zoom factor.
    - max_lighting: Maximum lighting (brightness/contrast) change.
    - max_warp: Maximum perspective-warp magnitude.
    - p_affine: Probability of applying each affine transform.
    - p_lighting: Probability of applying each lighting transform.
    - xtra_tfms: Additional transforms to append.
    - size: Target size for the transforms.
    - mode: Interpolation mode.
    - pad_mode: Padding mode.
    - align_corners: Align corners when resizing.
    - batch: Apply transforms at batch level (same draw for the whole batch).
    - min_scale: Minimum scale factor.

    Returns:
    - List of transform objects.
    """
class Resize(RandTransform):
    """Resize images to `size` using `method` (cropping by default)."""

    # resamples: (image resample filter, mask resample filter) -- nearest for
    # the second entry, presumably to keep mask labels discrete; confirm.
    def __init__(self, size, method=ResizeMethod.Crop, pad_mode=PadMode.Reflection,
                 resamples=(Image.BILINEAR, Image.NEAREST), **kwargs): ...
class RandomResizedCrop(RandTransform):
    """Random crop with resize, as in standard ImageNet training.

    - size: output size.
    - min_scale: minimum fraction of the image area to keep.
    - ratio: allowed aspect-ratio range of the crop.
    - resamples: (image, mask) resample filters.
    - val_xtra: extra margin used at validation time (assumed -- confirm).
    """

    def __init__(self, size, min_scale=0.08, ratio=(3/4, 4/3), resamples=(Image.BILINEAR, Image.NEAREST),
                 val_xtra=0.14, **kwargs): ...
class CropPad(Transform):
    """Crop or pad the input to `size`, padding with `pad_mode`."""

    def __init__(self, size, pad_mode=PadMode.Reflection, **kwargs): ...
class FlipItem(RandTransform):
    """Random horizontal/vertical flips.

    - p: probability of applying the flip.
    """

    def __init__(self, p=0.5): ...
class DihedralItem(RandTransform):
    """Random 90-degree rotations and flips (the dihedral group of the square).

    - p: probability of applying a transform.
    """

    def __init__(self, p=0.5): ...
class Brightness(RandTransform):
    """Random brightness adjustment.

    - max_lighting: maximum magnitude of the change.
    - p: probability of applying the transform.
    - draw: fixed value(s) to use instead of a random draw (assumed -- confirm).
    - batch: use the same draw for the whole batch.
    """

    def __init__(self, max_lighting=0.2, p=0.75, draw=None, batch=False): ...
class Contrast(RandTransform):
    """Random contrast adjustment.

    - max_lighting: maximum magnitude of the change.
    - p: probability of applying the transform.
    - draw: fixed value(s) to use instead of a random draw (assumed -- confirm).
    - batch: use the same draw for the whole batch.
    """

    def __init__(self, max_lighting=0.2, p=0.75, draw=None, batch=False): ...
class Saturation(RandTransform):
    """Random saturation adjustment.

    - max_lighting: maximum magnitude of the change.
    - p: probability of applying the transform.
    - draw: fixed value(s) to use instead of a random draw (assumed -- confirm).
    - batch: use the same draw for the whole batch.
    """

    def __init__(self, max_lighting=0.2, p=0.75, draw=None, batch=False): ...
class Hue(RandTransform):
    """Random hue shift.

    - max_hue: maximum magnitude of the hue shift.
    - p: probability of applying the transform.
    - draw: fixed value(s) to use instead of a random draw (assumed -- confirm).
    - batch: use the same draw for the whole batch.
    """

    def __init__(self, max_hue=0.1, p=0.75, draw=None, batch=False): ...
class Cutout(RandTransform):
    """Random rectangular occlusion (cutout) augmentation.

    - n_holes: number of rectangles to cut out.
    - length: side length of each rectangle, in pixels (assumed -- confirm).
    - p: probability of applying the transform.
    """

    def __init__(self, n_holes=1, length=40, p=0.5): ...
class RandomErasing(RandTransform):
    """Random erasing augmentation: blank out random patches of the input.

    - p: probability of applying the transform.
    - sh: maximum fraction of the area to erase (assumed -- confirm).
    - min_aspect: minimum aspect ratio of the erased patch.
    - max_count: maximum number of patches to erase.
    """

    def __init__(self, p=0.5, sh=0.4, min_aspect=0.3, max_count=1): ...

# Access to various pre-trained model architectures optimized for different tasks.
def xresnet18(pretrained=False, **kwargs):
    """XResNet-18 architecture; set `pretrained=True` to load pre-trained weights."""
def xresnet34(pretrained=False, **kwargs):
    """XResNet-34 architecture; set `pretrained=True` to load pre-trained weights."""
def xresnet50(pretrained=False, **kwargs):
    """XResNet-50 architecture; set `pretrained=True` to load pre-trained weights."""
def xresnet101(pretrained=False, **kwargs):
    """XResNet-101 architecture; set `pretrained=True` to load pre-trained weights."""
def xresnet152(pretrained=False, **kwargs):
    """XResNet-152 architecture; set `pretrained=True` to load pre-trained weights."""
class XResNet(nn.Sequential):
    """Configurable XResNet architecture.

    - block: residual block class.
    - layers: number of blocks per stage.
    - num_classes: size of the final classification layer.
    - zero_init_residual: zero-init the last norm layer of each residual
      block (assumed, following the ResNet convention -- confirm).
    - groups / width_per_group: grouped-convolution configuration.
    - replace_stride_with_dilation: swap strides for dilation per stage.
    - norm_layer: normalization layer constructor.
    - act_cls: activation class (from fastai defaults).
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, act_cls=defaults.activation, **kwargs): ...
class DynamicUnet(SequentialEx):
    """U-Net built dynamically around an arbitrary `encoder` for segmentation.

    - encoder: backbone network producing the downsampling path.
    - n_out: number of output channels/classes.
    - img_size: input image size.
    - blur / blur_final: apply blur in upsampling blocks / the final block.
    - self_attention: add a self-attention layer.
    - y_range: range for regression outputs.
    - last_cross: add a final cross-connection with the input (assumed -- confirm).
    - bottle: use a bottleneck in the middle block (assumed -- confirm).
    """

    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True,
                 self_attention=False, y_range=None, last_cross=True,
                 bottle=False, **kwargs): ...
class TimmBody(nn.Module):
    """Model body built from a `timm` architecture.

    - arch: timm architecture name or callable.
    - pretrained: load pre-trained weights.
    - cut: where to cut the model into a body.
    - n_in: number of input channels (3 for RGB).
    """

    def __init__(self, arch, pretrained=True, cut=None, n_in=3): ...

# Utility functions for computer vision tasks.
# Normalization statistics for common datasets, as ([channel means], [channel stds]).
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # RGB
cifar_stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])  # RGB
mnist_stats = ([0.131], [0.308])  # single grayscale channel
def download_images(dest, urls, max_pics=1000, n_workers=8, timeout=4):
    """Download images from `urls` into `dest` concurrently.

    Parameters:
    - dest: Destination directory.
    - urls: List of image URLs.
    - max_pics: Maximum number of images to download.
    - n_workers: Number of worker threads.
    - timeout: Per-download timeout, presumably in seconds -- confirm.
    """
def verify_images(fns):
    """Verify that image files can be opened.

    Parameters:
    - fns: List of image filenames.

    Returns:
    - List of filenames that failed verification (empty when all are valid).
    """
def show_image(im, ax=None, figsize=None, title=None, ctx=None, **kwargs):
    """Display a single image on matplotlib axes `ax` (or `ctx`), with optional `title`."""
def show_images(ims, nrows=1, ncols=None, titles=None, figsize=None, **kwargs):
    """Display `ims` in an `nrows` x `ncols` grid, with optional per-image `titles`."""
def subplots(nrows=1, ncols=1, figsize=None, imsize=3, add_vert=0, **kwargs):
    """Create matplotlib subplots with fastai styling.

    Parameters:
    - nrows / ncols: grid shape.
    - figsize: explicit figure size; when None, presumably derived from
      `imsize` and the grid shape -- confirm against the implementation.
    - imsize: per-axes size.
    - add_vert: extra vertical space.
    """

# Specialized functionality for image segmentation tasks.
class SegmentationDataLoaders(DataLoaders):
    """DataLoaders for segmentation tasks (paired images and masks)."""

    @classmethod
    def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, **kwargs):
        """Create SegmentationDataLoaders from `label_func`, which maps an
        image filename to the path of its mask.

        - path: data directory.
        - fnames: image filenames.
        - valid_pct: random validation split percentage.
        """
class MaskBlock(TransformBlock):
    """Transform block for segmentation masks.

    - codes: class codes/names for the mask values.
    """

    def __init__(self, codes=None): ...
def DiceLoss():
    """Dice loss for segmentation (overlap-based, complements pixel-wise losses)."""
def JaccardLoss():
    """Jaccard (intersection-over-union) loss for segmentation."""
def FocalLoss(alpha=1, gamma=2):
    """Focal loss for handling class imbalance.

    Parameters:
    - alpha: class-balancing weight.
    - gamma: focusing parameter; higher values down-weight easy examples.
    """

# Install with Tessl CLI
# npx tessl i tessl/pypi-fastai