PyTorch implementations of transformer-based language models including BERT, OpenAI GPT, GPT-2, and Transformer-XL with pre-trained models, tokenizers, and utilities for NLP tasks
—
File handling, caching, and model loading utilities providing automatic download and caching of pre-trained models, conversion from TensorFlow checkpoints, and various file system operations for managing model assets.
Utilities for automatically downloading, caching, and managing pre-trained model files with support for URLs, local paths, and cloud storage.
def cached_path(url_or_filename, cache_dir=None):
    """
    Download and cache files from URLs or return local file paths.

    Given a URL or local file path, this function downloads the file (if it
    is a URL) to a local cache directory and returns the path to the cached
    file.

    Args:
        url_or_filename (str): URL to download or local file path.
        cache_dir (str, optional): Directory to cache files. If None, uses
            the default cache directory.

    Returns:
        str: Path to the cached or local file.

    Raises:
        EnvironmentError: If the file cannot be found or downloaded.
    """

# Standard filenames and cache directory configuration for model management.
# Default cache directory for storing downloaded model files and checkpoints.
PYTORCH_PRETRAINED_BERT_CACHE = "~/.pytorch_pretrained_bert"

# Standard filename for model configuration files.
CONFIG_NAME = "config.json"

# Standard filename for PyTorch model weight files.
WEIGHTS_NAME = "pytorch_model.bin"
Functions to convert TensorFlow checkpoints to PyTorch format for all supported model architectures.
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """
    Load TensorFlow BERT checkpoint weights into PyTorch BERT model.

    Args:
        model: PyTorch BERT model instance (any BERT variant).
        tf_checkpoint_path (str): Path to TensorFlow checkpoint file.

    Returns:
        PyTorch model with loaded TensorFlow weights.

    Raises:
        ValueError: If checkpoint format is incompatible.
        FileNotFoundError: If checkpoint file doesn't exist.
    """

def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """
    Load TensorFlow OpenAI GPT checkpoint into PyTorch model.

    Args:
        model: PyTorch OpenAI GPT model instance.
        openai_checkpoint_folder_path (str): Path to folder containing TF
            checkpoint files.

    Returns:
        PyTorch model with loaded TensorFlow weights.

    Raises:
        ValueError: If checkpoint format is incompatible.
        FileNotFoundError: If checkpoint files don't exist.
    """

def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """
    Load TensorFlow GPT-2 checkpoint into PyTorch model.

    Args:
        model: PyTorch GPT-2 model instance.
        gpt2_checkpoint_path (str): Path to TensorFlow GPT-2 checkpoint.

    Returns:
        PyTorch model with loaded TensorFlow weights.

    Raises:
        ValueError: If checkpoint format is incompatible.
        FileNotFoundError: If checkpoint file doesn't exist.
    """

def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """
    Load TensorFlow Transformer-XL checkpoint into PyTorch model.

    Args:
        model: PyTorch Transformer-XL model instance.
        config: TransfoXLConfig instance.
        tf_path (str): Path to TensorFlow checkpoint.

    Returns:
        PyTorch model with loaded TensorFlow weights.

    Raises:
        ValueError: If checkpoint format is incompatible.
        FileNotFoundError: If checkpoint file doesn't exist.
    """

from pytorch_pretrained_bert import cached_path
# Download and cache from URL
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin"
local_path = cached_path(model_url)
print(f"Model cached at: {local_path}")

# Use local file path (returns as-is)
local_file = "/path/to/local/model.bin"
path = cached_path(local_file)
print(f"Local file: {path}")

# Use custom cache directory
custom_cache = cached_path(model_url, cache_dir="./my_models/")
print(f"Cached in custom directory: {custom_cache}")

from pytorch_pretrained_bert import (
    BertModel, BertConfig, load_tf_weights_in_bert,
    OpenAIGPTModel, OpenAIGPTConfig, load_tf_weights_in_openai_gpt
)
import torch  # required by torch.save below; missing from the original example

# Convert BERT from TensorFlow
bert_config = BertConfig.from_json_file("bert_config.json")
bert_model = BertModel(bert_config)
load_tf_weights_in_bert(bert_model, "bert_model.ckpt")

# Save converted model
torch.save(bert_model.state_dict(), "pytorch_bert_model.bin")

# Convert OpenAI GPT from TensorFlow
gpt_config = OpenAIGPTConfig.from_json_file("openai_gpt_config.json")
gpt_model = OpenAIGPTModel(gpt_config)
load_tf_weights_in_openai_gpt(gpt_model, "./openai_gpt_checkpoint/")

# Save converted model
torch.save(gpt_model.state_dict(), "pytorch_openai_gpt.bin")

from pytorch_pretrained_bert import (
    PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
)
import os

# Check default cache directory
cache_dir = os.path.expanduser(PYTORCH_PRETRAINED_BERT_CACHE)
print(f"Default cache: {cache_dir}")

# Check for standard model files
model_weights = os.path.join(cache_dir, WEIGHTS_NAME)
model_config = os.path.join(cache_dir, CONFIG_NAME)
print(f"Expected weights: {model_weights}")
print(f"Expected config: {model_config}")

from pytorch_pretrained_bert import cached_path
import os
def safe_download(url_or_path, cache_dir=None):
    """Safely download or access file with error handling.

    Resolves *url_or_path* through cached_path and reports the outcome.
    Returns the resolved local path on success, or None if the file could
    not be accessed for any reason.
    """
    try:
        path = cached_path(url_or_path, cache_dir=cache_dir)
        # Guard clause: bail out early when the resolved path is missing.
        if not os.path.exists(path):
            print(f"File not found: {path}")
            return None
        size = os.path.getsize(path)
        print(f"Successfully accessed: {path} ({size} bytes)")
        return path
    except EnvironmentError as e:
        # cached_path signals download/lookup failures this way.
        print(f"Error accessing {url_or_path}: {e}")
        return None
    except Exception as e:
        # Catch-all keeps the demo loop running on unforeseen failures.
        print(f"Unexpected error: {e}")
        return None
# Test with various inputs
test_files = [
"https://invalid-url.com/model.bin", # Invalid URL
"/nonexistent/path/model.bin", # Nonexistent local path
"https://github.com/", # Valid URL, invalid model
]
for test_file in test_files:
result = safe_download(test_file)
print(f"Result for {test_file}: {result}\n")Install with Tessl CLI
npx tessl i tessl/pypi-pytorch-pretrained-bert