State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Comprehensive model management with automatic selection and loading for 350+ architectures. The model system provides consistent APIs across text, vision, audio, and multimodal domains while supporting PyTorch, TensorFlow, and JAX frameworks.
Automatic model selection based on model names or configurations, eliminating the need to know specific architecture classes.
class AutoModel:
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
*model_args,
config: PretrainedConfig = None,
cache_dir: Union[str, os.PathLike] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Union[bool, str] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs
) -> PreTrainedModel:
"""
Load a pretrained model automatically detecting the architecture.
Args:
pretrained_model_name_or_path: Model name or local path
config: Model configuration (auto-detected if None)
cache_dir: Custom cache directory
ignore_mismatched_sizes: Ignore size mismatches when loading
force_download: Force fresh download
local_files_only: Only use local files
token: Hugging Face authentication token
revision: Model revision/branch
use_safetensors: Use safetensors format when available
Returns:
Loaded model instance
"""Pre-configured models for common tasks with appropriate heads and loss functions.
class AutoModelForSequenceClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for sequence classification tasks."""
class AutoModelForTokenClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for token classification (NER, POS tagging)."""
class AutoModelForQuestionAnswering:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for extractive question answering."""
class AutoModelForMaskedLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for masked language modeling."""
class AutoModelForCausalLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for causal language modeling (text generation)."""
class AutoModelForSeq2SeqLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for sequence-to-sequence tasks."""
class AutoModelForImageClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for image classification."""
class AutoModelForObjectDetection:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> PreTrainedModel:
"""Load model for object detection."""Usage examples:
# Text classification
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-uncased",
num_labels=3
)
# Text generation
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Image classification
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

Foundation classes that all specific model implementations inherit from.
class PreTrainedModel:
"""Base class for all PyTorch models."""
def __init__(self, config: PretrainedConfig, *inputs, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs
) -> 'PreTrainedModel':
"""Load pretrained model weights and configuration."""
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Dict[str, torch.Tensor] = None,
save_function: Callable = None,
push_to_hub: bool = False,
max_shard_size: Union[int, str] = "5GB",
safe_serialization: bool = True,
**kwargs
) -> None:
"""Save model weights and configuration."""
def push_to_hub(
self,
repo_id: str,
use_temp_dir: bool = None,
commit_message: str = None,
private: bool = None,
token: Union[bool, str] = None,
**kwargs
) -> str:
"""Upload model to Hugging Face Hub."""
def forward(self, **kwargs) -> Union[torch.Tensor, ModelOutput]:
"""Forward pass through the model."""
def generate(self, **kwargs) -> torch.Tensor:
"""Generate sequences (available on generative models)."""
def resize_token_embeddings(
self,
new_num_tokens: int = None
) -> torch.nn.Embedding:
"""Resize input token embeddings matrix."""
def get_input_embeddings(self) -> torch.nn.Module:
"""Get input embeddings layer."""
def set_input_embeddings(self, value: torch.nn.Module) -> None:
"""Set input embeddings layer."""
def tie_weights(self) -> None:
"""Tie input and output embeddings if specified in config."""
def gradient_checkpointing_enable(self) -> None:
"""Enable gradient checkpointing for training."""
def gradient_checkpointing_disable(self) -> None:
"""Disable gradient checkpointing."""TensorFlow implementations of all model architectures with Keras compatibility.
class TFPreTrainedModel:
"""Base class for all TensorFlow models."""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs
) -> 'TFPreTrainedModel':
"""Load pretrained TensorFlow model."""
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
**kwargs
) -> None:
"""Save TensorFlow model."""
def call(self, **kwargs) -> Union[tf.Tensor, TFModelOutput]:
"""Forward pass through TensorFlow model."""
# Task-specific TF models
class TFAutoModel:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel
class TFAutoModelForSequenceClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel
class TFAutoModelForCausalLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> TFPreTrainedModel

JAX implementations with Flax for high-performance training and inference.
class FlaxPreTrainedModel:
"""Base class for all Flax/JAX models."""
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs
) -> 'FlaxPreTrainedModel':
"""Load pretrained Flax model."""
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
**kwargs
) -> None:
"""Save Flax model."""
def __call__(self, **kwargs) -> Union[jnp.ndarray, FlaxModelOutput]:
"""Forward pass through Flax model."""
# Task-specific Flax models
class FlaxAutoModel:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> FlaxPreTrainedModel
class FlaxAutoModelForCausalLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> FlaxPreTrainedModel

Configuration classes that define model architectures and hyperparameters.
class AutoConfig:
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs
) -> PretrainedConfig:
"""Load model configuration automatically."""
class PretrainedConfig:
"""Base configuration class for all models."""
def __init__(self, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs
) -> 'PretrainedConfig':
"""Load configuration from pretrained model."""
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
**kwargs
) -> None:
"""Save configuration to directory."""
def to_dict(self) -> Dict[str, Any]:
"""Convert configuration to dictionary."""
def to_json_file(self, json_file_path: Union[str, os.PathLike]) -> None:
"""Save configuration to JSON file."""class BertModel(PreTrainedModel):
"""BERT model for encoding tasks."""
class BertForSequenceClassification(PreTrainedModel):
"""BERT model with sequence classification head."""
class BertForTokenClassification(PreTrainedModel):
"""BERT model with token classification head."""
class BertForQuestionAnswering(PreTrainedModel):
"""BERT model with question answering head."""
class BertForMaskedLM(PreTrainedModel):
"""BERT model with masked language modeling head."""class GPT2Model(PreTrainedModel):
"""GPT-2 model for generation tasks."""
class GPT2LMHeadModel(PreTrainedModel):
"""GPT-2 model with language modeling head."""
class GPTNeoModel(PreTrainedModel):
"""GPT-Neo model architecture."""
class GPTNeoXModel(PreTrainedModel):
"""GPT-NeoX model architecture."""
class GPTJModel(PreTrainedModel):
"""GPT-J model architecture."""class T5Model(PreTrainedModel):
"""T5 encoder-decoder model."""
class T5ForConditionalGeneration(PreTrainedModel):
"""T5 model with conditional generation head."""
class T5EncoderModel(PreTrainedModel):
"""T5 encoder-only model."""class ViTModel(PreTrainedModel):
"""Vision Transformer model."""
class ViTForImageClassification(PreTrainedModel):
"""ViT model with image classification head."""
class DetrModel(PreTrainedModel):
"""DETR object detection model."""
class DetrForObjectDetection(PreTrainedModel):
"""DETR model with object detection head."""class CLIPModel(PreTrainedModel):
"""CLIP vision-language model."""
class CLIPTextModel(PreTrainedModel):
"""CLIP text encoder."""
class CLIPVisionModel(PreTrainedModel):
"""CLIP vision encoder."""
class BlipModel(PreTrainedModel):
"""BLIP multimodal model."""
class BlipForConditionalGeneration(PreTrainedModel):
"""BLIP model with conditional generation."""Standard output formats for different model types:
class BaseModelOutput:
"""Base output type for encoder models."""
last_hidden_state: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None
class BaseModelOutputWithPooling:
"""Base output with pooling for classification models."""
last_hidden_state: torch.Tensor
pooler_output: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None
class CausalLMOutput:
"""Output for causal language models."""
loss: Optional[torch.Tensor] = None
logits: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None
class SequenceClassifierOutput:
"""Output for sequence classification models."""
loss: Optional[torch.Tensor] = None
logits: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None
class TokenClassifierOutput:
"""Output for token classification models."""
loss: Optional[torch.Tensor] = None
logits: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None
class QuestionAnsweringModelOutput:
"""Output for question answering models."""
loss: Optional[torch.Tensor] = None
start_logits: torch.Tensor
end_logits: torch.Tensor
hidden_states: Optional[Tuple[torch.Tensor]] = None
attentions: Optional[Tuple[torch.Tensor]] = None

Common patterns for working with models:
# Load model with custom configuration
config = AutoConfig.from_pretrained("bert-base-uncased")
config.num_labels = 3
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-uncased",
config=config
)
# Load model with custom dtype and device
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
torch_dtype=torch.float16,
device_map="auto"
)
# Save model locally
model.save_pretrained("./my-model")
# Upload to Hub
model.push_to_hub("username/my-model", private=True)
# Load from local directory
model = AutoModel.from_pretrained("./my-model")Install with Tessl CLI
npx tessl i tessl/pypi-transformers