State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Advanced optimization techniques including quantization, mixed precision training, hardware acceleration, and memory efficiency improvements for both inference and training workflows.
Reduce model memory footprint and increase inference speed through various quantization techniques.
class BitsAndBytesConfig:
    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        llm_int8_threshold: float = 6.0,
        llm_int8_skip_modules: Optional[List[str]] = None,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        bnb_4bit_compute_dtype: Optional[torch.dtype] = None,
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        bnb_4bit_quant_storage: Optional[torch.dtype] = None,
        **kwargs
    ):
        """
        Configuration for BitsAndBytes quantization.

        Args:
            load_in_8bit: Enable 8-bit quantization
            load_in_4bit: Enable 4-bit quantization
            llm_int8_threshold: Outlier threshold for int8 quantization
            llm_int8_skip_modules: Modules to exclude from quantization
            llm_int8_enable_fp32_cpu_offload: Keep non-quantized fp32 modules on CPU when the model does not fit on GPU
            llm_int8_has_fp16_weight: Run LLM.int8() with fp16 main weights
            bnb_4bit_compute_dtype: Compute dtype for 4-bit quantization
            bnb_4bit_quant_type: Quantization type ("fp4", "nf4")
            bnb_4bit_use_double_quant: Use nested (double) quantization
            bnb_4bit_quant_storage: Storage dtype for the quantized weights
        """
class GPTQConfig:
    def __init__(
        self,
        bits: int = 4,
        tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        group_size: int = 128,
        damp_percent: float = 0.1,
        desc_act: bool = False,
        static_groups: bool = False,
        sym: bool = True,
        true_sequential: bool = True,
        model_name_or_path: Optional[str] = None,
        model_seqlen: Optional[int] = None,
        block_name_to_quantize: Optional[str] = None,
        module_name_preceding_first_block: Optional[List[str]] = None,
        batch_size: int = 1,
        pad_token_id: Optional[int] = None,
        use_exllama: Optional[bool] = None,
        max_input_length: Optional[int] = None,
        exllama_config: Optional[Dict[str, Any]] = None,
        cache_block_outputs: bool = True,
        modules_in_block_to_quantize: Optional[List[List[str]]] = None,
        **kwargs
    ):
        """
        Configuration for GPTQ quantization.

        Args:
            bits: Number of bits for quantization
            tokenizer: Tokenizer used to prepare the calibration dataset
            dataset: Calibration dataset name or list of samples
            group_size: Group size for quantization
            damp_percent: Damping percentage used during quantization
            desc_act: Quantize columns in order of decreasing activation size (act-order)
            static_groups: Use static groups
            sym: Use symmetric quantization
            true_sequential: Quantize sequentially within each transformer block
            model_seqlen: Model sequence length
            batch_size: Batch size for calibration
            use_exllama: Use ExLlama kernels
            max_input_length: Maximum input length
        """
class AwqConfig:
    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        version: str = "GEMM",
        backend: str = "autoawq",
        do_fuse: Optional[bool] = None,
        fuse_max_seq_len: Optional[int] = None,
        modules_to_fuse: Optional[Dict] = None,
        **kwargs
    ):
        """
        Configuration for AWQ quantization.

        Args:
            bits: Number of bits for quantization
            group_size: Group size for quantization
            zero_point: Use zero point quantization
            version: AWQ version ("GEMM", "GEMV")
            backend: Backend implementation
            do_fuse: Enable module fusion
            fuse_max_seq_len: Maximum sequence length for fusion
            modules_to_fuse: Specific modules to fuse
        """

Optimize training speed and memory usage with mixed precision techniques.
class TrainingArguments:
    """Training arguments with mixed precision options."""

    def __init__(
        self,
        # Mixed precision options
        fp16: bool = False,
        bf16: bool = False,
        fp16_opt_level: str = "O1",
        fp16_backend: str = "auto",
        fp16_full_eval: bool = False,
        bf16_full_eval: bool = False,
        tf32: Optional[bool] = None,
        dataloader_pin_memory: bool = True,
        # Other training args...
        **kwargs
    ):
        """
        Configure mixed precision training.

        Args:
            fp16: Enable 16-bit floating point training
            bf16: Enable bfloat16 training (better numerical stability)
            fp16_opt_level: Optimization level for Apex ("O0", "O1", "O2", "O3")
            fp16_backend: Backend for fp16 ("auto", "apex", "amp")
            fp16_full_eval: Use fp16 for evaluation
            bf16_full_eval: Use bf16 for evaluation
            tf32: Enable TensorFloat-32 on Ampere GPUs
            dataloader_pin_memory: Pin memory for faster data transfer
        """

Trade computation for memory by recomputing activations during the backward pass.
class PreTrainedModel:
    """Model with gradient checkpointing support."""

    def gradient_checkpointing_enable(
        self,
        gradient_checkpointing_kwargs: Optional[Dict] = None
    ) -> None:
        """
        Enable gradient checkpointing for the model.

        Args:
            gradient_checkpointing_kwargs: Additional arguments for checkpointing
        """

    def gradient_checkpointing_disable(self) -> None:
        """Disable gradient checkpointing."""
class TrainingArguments:
    def __init__(
        self,
        gradient_checkpointing: bool = False,
        gradient_checkpointing_kwargs: Optional[Dict] = None,
        **kwargs
    ):
        """
        Configure gradient checkpointing in training.

        Args:
            gradient_checkpointing: Enable gradient checkpointing
            gradient_checkpointing_kwargs: Additional checkpointing options
        """

Advanced memory management techniques for large models.
def enable_memory_efficient_attention():
    """Enable memory-efficient attention implementations."""

def get_memory_footprint_mb(
    model: torch.nn.Module,
    return_buffers: bool = True
) -> int:
    """
    Get model memory footprint in MB.

    Args:
        model: PyTorch model
        return_buffers: Include buffer memory

    Returns:
        Memory footprint in megabytes
    """
class DeepSpeedConfig:
    """Configuration for DeepSpeed optimization."""

    @staticmethod
    def get_config(
        stage: int = 2,
        offload_optimizer: bool = False,
        offload_param: bool = False,
        reduce_bucket_size: int = 200000000,
        stage3_prefetch_bucket_size: int = 200000000,
        stage3_param_persistence_threshold: int = 1000000,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Get DeepSpeed configuration dictionary.

        Args:
            stage: ZeRO stage (1, 2, or 3)
            offload_optimizer: Offload optimizer states to CPU
            offload_param: Offload parameters to CPU
            reduce_bucket_size: Gradient reduction bucket size
            stage3_prefetch_bucket_size: Parameter prefetch bucket size
            stage3_param_persistence_threshold: Parameter persistence threshold

        Returns:
            DeepSpeed configuration dictionary
        """

Optimize for specific hardware platforms and accelerators.
class BetterTransformer:
    """Flash Attention and other optimized kernels."""

    @staticmethod
    def transform(
        model: PreTrainedModel,
        keep_original_model: bool = False,
        **kwargs
    ) -> PreTrainedModel:
        """
        Apply BetterTransformer optimizations.

        Args:
            model: Model to optimize
            keep_original_model: Keep reference to original model

        Returns:
            Optimized model with fast attention
        """

    @staticmethod
    def reverse(model: PreTrainedModel) -> PreTrainedModel:
        """Reverse BetterTransformer optimizations."""

# PyTorch 2.0 Compilation
def torch_compile_model(
    model: PreTrainedModel,
    backend: str = "inductor",
    mode: str = "default",
    **kwargs
) -> PreTrainedModel:
    """
    Compile model with PyTorch 2.0 torch.compile.

    Args:
        model: Model to compile
        backend: Compilation backend ("inductor", "aot_eager", etc.)
        mode: Compilation mode ("default", "reduce-overhead", "max-autotune")

    Returns:
        Compiled model
    """
class TrainingArguments:
    def __init__(
        self,
        torch_compile: bool = False,
        torch_compile_backend: Optional[str] = None,
        torch_compile_mode: Optional[str] = None,
        **kwargs
    ):
        """
        Configure PyTorch compilation in training.

        Args:
            torch_compile: Enable torch.compile
            torch_compile_backend: Compilation backend
            torch_compile_mode: Compilation mode
        """

Distribute large models across multiple devices.
class TrainingArguments:
    def __init__(
        self,
        # Data parallelism
        local_rank: int = -1,
        ddp_backend: Optional[str] = None,
        ddp_timeout: Optional[int] = 1800,
        ddp_find_unused_parameters: Optional[bool] = None,
        # Model parallelism
        fsdp: str = "",
        fsdp_min_num_params: int = 0,
        fsdp_config: Optional[str] = None,
        fsdp_transformer_layer_cls_to_wrap: Optional[str] = None,
        # Pipeline parallelism
        deepspeed: Optional[str] = None,
        **kwargs
    ):
        """
        Configure distributed training strategies.

        Args:
            local_rank: Local rank for distributed training
            ddp_backend: Distributed data parallel backend
            ddp_timeout: DDP timeout in seconds
            fsdp: Fully Sharded Data Parallel configuration
            fsdp_min_num_params: Minimum parameters for FSDP wrapping
            deepspeed: DeepSpeed configuration file path
        """
def load_model_with_device_map(
    model_name: str,
    device_map: Union[str, Dict] = "auto",
    max_memory: Optional[Dict] = None,
    offload_folder: Optional[str] = None,
    **kwargs
) -> PreTrainedModel:
    """
    Load model with automatic device mapping.

    Args:
        model_name: Model name or path
        device_map: Device mapping strategy or custom mapping
        max_memory: Maximum memory per device
        offload_folder: Folder for offloaded weights

    Returns:
        Model distributed across available devices
    """

Optimize models specifically for inference workloads.
def optimize_model_for_inference(
    model: PreTrainedModel,
    optimize_for_latency: bool = True,
    optimize_for_throughput: bool = False,
    use_bettertransformer: bool = True,
    use_torch_compile: bool = True,
    **kwargs
) -> PreTrainedModel:
    """
    Apply inference-specific optimizations.

    Args:
        model: Model to optimize
        optimize_for_latency: Optimize for low latency
        optimize_for_throughput: Optimize for high throughput
        use_bettertransformer: Apply BetterTransformer
        use_torch_compile: Use torch.compile

    Returns:
        Optimized model for inference
    """
class StaticCache:
    """Static key-value cache for improved inference performance."""

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initialize static cache.

        Args:
            config: Model configuration
            max_batch_size: Maximum batch size
            max_cache_len: Maximum cache length
            device: Device for cache tensors
            dtype: Data type for cache
        """

Optimize model loading and sharing through intelligent caching.
def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs
) -> Optional[str]:
    """
    Download and cache a file from the Hugging Face Hub.

    Args:
        path_or_repo_id: Repository ID or local path
        filename: File to download
        cache_dir: Custom cache directory
        force_download: Force fresh download
        resume_download: Resume interrupted download
        token: Authentication token
        revision: Model revision/branch
        local_files_only: Only use local files

    Returns:
        Path to cached file
    """
def clean_files_cache(
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    token: Optional[Union[bool, str]] = None
) -> None:
    """Clean up cached files to free disk space."""

def scan_cache_dir(
    cache_dir: Optional[Union[str, os.PathLike]] = None
) -> Dict[str, Any]:
    """Scan cache directory and return usage statistics."""

Common optimization patterns for different use cases:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# 8-bit quantization
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b",
    quantization_config=quantization_config,
    device_map="auto"
)

# 4-bit quantization with NF4
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/DialoGPT-large",
    quantization_config=quantization_config
)
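
# AWQ follows the same pattern but expects a checkpoint that was already
# quantized with AWQ; a sketch assuming autoawq is installed (the repository
# id below is only a placeholder).
from transformers import AwqConfig

quantization_config = AwqConfig(bits=4, do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "org/model-awq",   # placeholder for an AWQ-quantized checkpoint
    quantization_config=quantization_config,
    device_map="auto",
)
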
# Mixed precision training
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./results",
    bf16=True,                    # Use bfloat16 for better stability
    gradient_checkpointing=True,  # Save memory
    dataloader_pin_memory=True,   # Faster data loading
    torch_compile=True,           # PyTorch 2.0 compilation
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
# BetterTransformer optimization
from transformers import BetterTransformer
# Apply optimizations
model = BetterTransformer.transform(model)
# Use optimized model for inference
outputs = model.generate(**inputs, max_new_tokens=50)
# DeepSpeed training
training_args = TrainingArguments(
    output_dir="./results",
    deepspeed="ds_config.json",   # DeepSpeed config file
    gradient_checkpointing=True,
    fp16=True
)
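
# Instead of a JSON file, TrainingArguments also accepts an already-loaded
# config dict; a sketch using the DeepSpeedConfig helper described earlier on
# this page to build a ZeRO-3 setup with CPU offload.
ds_config = DeepSpeedConfig.get_config(
    stage=3,                   # ZeRO stage 3: shard params, grads and optimizer state
    offload_optimizer=True,    # move optimizer state to CPU RAM
    offload_param=True,        # move parameters to CPU RAM when not in use
)
training_args = TrainingArguments(
    output_dir="./results",
    deepspeed=ds_config,       # a dict is accepted in place of a file path
    bf16=True,
)
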
# Device mapping for large models
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/DialoGPT-large",
    device_map="auto",                   # Automatic device placement
    max_memory={0: "10GB", 1: "10GB"},   # Memory limits per GPU
    offload_folder="./offload"           # Offload unused weights
)
# Static cache for faster inference
from transformers import StaticCache

cache = StaticCache(
    config=model.config,
    max_batch_size=4,
    max_cache_len=512,
    device=model.device,
    dtype=torch.float16
)

# Use cache during generation
outputs = model.generate(
    **inputs,
    past_key_values=cache,
    use_cache=True,
    max_new_tokens=50
)

For Training:
- bf16=True for better numerical stability than fp16
- gradient_checkpointing=True for large models
- torch_compile=True with PyTorch 2.0+ for speed improvements

For Inference:
- torch_compile=True for repeated inference patterns

For Memory Optimization:
- device_map="auto" for automatic multi-GPU placement
- max_memory limits per device

Install with Tessl CLI
npx tessl i tessl/pypi-transformers