fastai simplifies training fast and accurate neural nets using modern best practices
npx @tessl/cli install tessl/pypi-fastai@2.8.0
A comprehensive deep learning library that simplifies training fast and accurate neural networks using modern best practices. Built on PyTorch, fastai provides high-level components that can quickly and easily provide state-of-the-art results in standard deep learning domains, and provides researchers with low-level components that can be mixed and matched to build new approaches.
pip install fastai
The main import patterns for fastai depend on the domain:
For vision tasks:
from fastai.vision.all import *
For text tasks:
from fastai.text.all import *
For tabular tasks:
from fastai.tabular.all import *
For basic functionality:
from fastai.basics import *
For collaborative filtering:
from fastai.collab import *
Basic usage (image classification example):
from fastai.vision.all import *
# Download a sample dataset (Oxford-IIIT Pet). `path` contains an 'images'
# folder with the photos and an 'annotations' folder (which includes .png
# trimap masks).
path = untar_data(URLs.PETS)


# Label function: in this dataset, file names of cat breeds start with an
# uppercase letter. from_name_func passes it the file *name* (a str).
# A named function is used instead of a lambda because from_name_func
# rejects lambdas on Windows and learn.export() cannot pickle a lambda.
def is_cat(x):
    return x[0].isupper()


# Create a data loader for image classification. Only the 'images' folder is
# scanned: get_image_files(path) would also recurse into 'annotations' and
# pick up the .png masks as if they were photos.
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path/'images'),
    valid_pct=0.2, seed=42,
    label_func=is_cat,
    item_tfms=Resize(224))
# Create a learner with a pre-trained model
learn = vision_learner(dls, resnet34, metrics=error_rate)
# Train the model
learn.fine_tune(4)
# Make predictions (replace the path below with an image of your own)
pred_class, pred_idx, outputs = learn.predict(path/'images'/'test_image.jpg')
fastai is built around several key architectural concepts:
The library follows a layered API design where high-level convenience functions build on lower-level flexible components, allowing both rapid prototyping and advanced customization.
Central training and learning infrastructure including the main Learner class, metrics, optimization, and model management utilities.
class Learner:
    """Central training object combining DataLoaders, a model, a loss and an optimizer.

    API summary only: method bodies are elided. The constructor takes the
    data (`dls`), the `model`, an optional `loss_func`, the optimizer factory
    `opt_func` and a default learning rate `lr`.
    """

    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=0.001, **kwargs): ...

    # Train for `n_epoch` epochs; lr / wd / cbs presumably override the
    # Learner-level defaults for this call only — confirm.
    def fit(self, n_epoch, lr=None, wd=None, cbs=None): ...

    # Transfer-learning entry point used in the quick start
    # (`learn.fine_tune(4)`); presumably runs `freeze_epochs` with the
    # pretrained body frozen before `epochs` of full training — confirm.
    def fine_tune(self, epochs, base_lr=2e-3, freeze_epochs=1, **kwargs): ...

    # Predict a single item; the quick start unpacks three values from the result.
    # NOTE(review): in fastai the second positional parameter appears to be
    # `rm_type_tfms`, so pass `with_input` by keyword — verify.
    def predict(self, item, with_input=False): ...
# Learner construction / loading helpers (signature summary; bodies elided).
# NOTE(review): load_learner presumably restores a Learner written by
# Learner.export, which fastai serializes with pickle — only load files from
# trusted sources. Also confirm the first parameter's name (fastai's docs
# appear to call it `fname`) before passing it by keyword.
def load_learner(path, cpu=True): ...
# Builds a Learner from DataLoaders and an architecture (the quick start calls
# it as vision_learner(dls, resnet34, metrics=error_rate)).
def vision_learner(dls, arch, normalize=True, n_out=None, **kwargs): ...
def text_classifier_learner(dls, arch, seq_len=72, **kwargs): ...
def tabular_learner(dls, layers=None, emb_szs=None, **kwargs): ...
Comprehensive data loading system with the DataBlock API, transforms, and domain-specific data loaders for flexible data pipeline construction.
class DataLoaders:
    """Container for one or more data loaders passed positionally.

    API summary only: bodies are elided.
    """

    def __init__(self, *loaders): ...

    # Alternate constructor: build the loaders from a DataBlock (`dblock`)
    # applied to `source`.
    @classmethod
    def from_dblock(cls, dblock, source, **kwargs): ...
class DataBlock:
    """Declarative description of a data pipeline (the DataBlock API).

    API summary only: bodies are elided. `blocks`, `getters` and `n_inp`
    describe the inputs/targets; extra options go through **kwargs.
    """

    def __init__(self, blocks=None, dl_type=None, getters=None, n_inp=None, **kwargs): ...

    # Build DataLoaders from `source`; presumably the same as
    # DataLoaders.from_dblock(self, source, **kwargs) — confirm.
    def dataloaders(self, source, **kwargs): ...
# Shorthand for the classmethod factories listed on each DataLoaders class in
# the domain sections of this page. `def Class.method(...)` is not valid
# Python; read each line as the call signature `Class.method(...)`.
# NOTE(review): the ImageDataLoaders.from_folder listing in the vision section
# shows no `valid_pct` parameter — confirm the 0.2 default shown here.
def ImageDataLoaders.from_folder(path, valid_pct=0.2, **kwargs): ...
def TextDataLoaders.from_folder(path, valid='valid', **kwargs): ...
def TabularDataLoaders.from_csv(path, y_names, **kwargs): ...
Complete computer vision toolkit including pre-trained models, data augmentation, specialized learners for classification and segmentation, and vision-specific utilities.
# Vision learner constructors (signature summary; bodies elided).
# unet_learner is presumably the segmentation counterpart of vision_learner
# (the section summary mentions segmentation learners) — confirm which
# `arch` values it accepts.
def vision_learner(dls, arch, normalize=True, n_out=None, **kwargs): ...
def unet_learner(dls, arch, normalize=True, **kwargs): ...
class ImageDataLoaders:
    """Factory classmethods that build image DataLoaders.

    API summary only: bodies are elided.
    """

    # Build from a directory tree; `train` / `valid` name the subfolders
    # presumably holding the training and validation images — confirm.
    @classmethod
    def from_folder(cls, path, train='train', valid='valid', **kwargs): ...

    # Build from a list of files, labelling each with `label_func`
    # (the quick start uses this with a name-based label function).
    # NOTE(review): fastai appears to reject lambda label functions on
    # Windows — prefer a named function.
    @classmethod
    def from_name_func(cls, path, fnames, label_func, **kwargs): ...
def aug_transforms(mult=1.0, do_flip=True, flip_vert=False, **kwargs): ...
Text processing and NLP capabilities including language models, text classification, tokenization, and text-specific data processing.
# NLP learner constructors (signature summary; bodies elided).
# text_classifier_learner is listed with the same signature in the core
# training section.
def language_model_learner(dls, arch, config=None, **kwargs): ...
def text_classifier_learner(dls, arch, seq_len=72, **kwargs): ...
class TextDataLoaders:
    """Factory classmethods that build text DataLoaders.

    API summary only: bodies are elided.
    """

    # Build from a directory tree; `valid` names the validation subfolder.
    @classmethod
    def from_folder(cls, path, valid='valid', **kwargs): ...

    # Build from a CSV file; `text_col` / `label_col` presumably name the
    # columns holding the text and its label — confirm the defaults.
    @classmethod
    def from_csv(cls, path, text_col='text', label_col='label', **kwargs): ...
# Tokenizer class (body elided).
# NOTE(review): in recent fastai releases WordTokenizer appears to be an alias
# of SpacyTokenizer — confirm for the installed version.
class WordTokenizer: ...
class SubwordTokenizer: ...
Tabular data processing and modeling including preprocessing transforms, neural network architectures designed for structured data, and tabular-specific utilities.
# Tabular learner constructor (signature summary; body elided).
# This listing includes `n_out`, which the core-section listing of the same
# function omits. layers / emb_szs presumably size the hidden layers and the
# categorical embeddings — confirm.
def tabular_learner(dls, layers=None, emb_szs=None, n_out=None, **kwargs): ...
class TabularDataLoaders:
    """Factory classmethods that build tabular DataLoaders.

    API summary only: bodies are elided.
    """

    # Build from a CSV file. `y_names` names the target column(s);
    # `cat_names` / `cont_names` presumably list the categorical and
    # continuous input columns — confirm behaviour when left as None.
    @classmethod
    def from_csv(cls, path, y_names, cat_names=None, cont_names=None, **kwargs): ...
# Tabular preprocessing transforms (bodies elided); the section summary
# describes them as preprocessing for structured data. Their exact effect
# (e.g. category encoding, missing-value filling) is not shown here — confirm.
class Categorify: ...
class FillMissing: ...
class Normalize: ...
Recommendation system capabilities including specialized learners and models for collaborative filtering tasks.
# Collaborative-filtering learner constructor (body elided).
# n_factors=50 is the listed default, presumably the embedding width for
# users and items — confirm.
def collab_learner(dls, n_factors=50, **kwargs): ...
class CollabDataLoaders:
    """Factory classmethods that build collaborative-filtering DataLoaders.

    API summary only: bodies are elided.
    """

    # Build from a CSV file; `user_name` / `item_name` presumably select the
    # user and item columns — confirm what is used when they are None.
    @classmethod
    def from_csv(cls, path, user_name=None, item_name=None, **kwargs): ...
class EmbeddingDotBias: ...
Extensive callback system for customizing the training loop including progress tracking, learning rate scheduling, regularization, and logging.
class Callback:
    """Base class for hooks into the training loop.

    API summary only: bodies are elided. Subclasses override the event
    methods they need; this listing shows only the `before_*` events.
    """

    def before_fit(self): ...

    def before_epoch(self): ...

    def before_batch(self): ...
# Built-in Callback subclasses (bodies elided).
# NOTE(review): fastai 2.x does not appear to export a class named
# `OneCycleTraining`; one-cycle schedules are normally driven by
# `Learner.fit_one_cycle` (backed by `ParamScheduler`) — verify before use.
class MixedPrecision(Callback): ...
class OneCycleTraining(Callback): ...
class EarlyStoppingCallback(Callback): ...
class SaveModelCallback(Callback): ...
Comprehensive metrics for evaluating model performance and loss functions for training across different domains and tasks.
# Metrics (functions of predictions `inp` and targets `targ`) and loss
# classes; bodies elided. error_rate is the metric passed to vision_learner
# in the quick-start example.
# NOTE(review): the *Flat losses presumably flatten input/target before
# computing the loss — confirm.
def accuracy(inp, targ): ...
def error_rate(inp, targ): ...
def top_k_accuracy(inp, targ, k=5): ...
class CrossEntropyLossFlat: ...
class MSELossFlat: ...
class FocalLoss: ...
Tools for understanding and interpreting model predictions including visualization utilities and analysis methods.
class ClassificationInterpretation:
    """Inspect and visualize a classifier's predictions.

    API summary only: bodies are elided.
    """

    # Alternate constructor: build the interpretation from a Learner.
    @classmethod
    def from_learner(cls, learn, **kwargs): ...

    def plot_confusion_matrix(self, **kwargs): ...
    def plot_top_losses(self, k, **kwargs): ...
Specialized tools for working with DICOM medical imaging files including CT scans, MRI, X-rays, and other medical imaging modalities with proper windowing, normalization, and processing.
# DICOM file helpers (signature summary; bodies elided).
# get_dicom_files presumably collects DICOM files under `path` (recursing by
# default, optionally limited to `folders`) — confirm the matched extensions.
# NOTE(review): dcmread presumably reads a single DICOM file; `force`
# likely relaxes header validation — confirm against the underlying reader.
def get_dicom_files(path, recurse=True, folders=None): ...
def dcmread(fn, force=False): ...
class DicomSegmentationDataLoaders(DataLoaders):
    """DataLoaders for segmentation tasks on DICOM images.

    API summary only: bodies are elided.
    """

    # Build from a list of DICOM files, obtaining each target via
    # `label_func` — confirm what label_func must return (e.g. a mask path).
    @classmethod
    def from_label_func(cls, path, fnames, label_func, **kwargs): ...
# Image types for DICOM pixel data (bodies elided): a tensor type extending
# TensorImage and a PIL-based type extending PILBase.
class TensorDicom(TensorImage): ...
class PILDicom(PILBase): ...
# Predefined medical windows, keyed by anatomy/use case.
# NOTE(review): each pair is presumably (window width, window level) in
# Hounsfield units, e.g. lungs=(1500,-600) — confirm against the windowing API.
# SimpleNamespace (from the standard `types` module) is assumed to be in scope.
dicom_windows = SimpleNamespace(
brain=(80,40), subdural=(254,100), stroke=(8,32),
brain_bone=(2800,600), lungs=(1500,-600), liver=(150,30)
)