CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-pytorch-pretrained-bert

PyTorch implementations of transformer-based language models including BERT, OpenAI GPT, GPT-2, and Transformer-XL with pre-trained models, tokenizers, and utilities for NLP tasks

Pending
Overview
Eval results
Files

docs/utilities.md

Utilities

File handling, caching, and model loading utilities providing automatic download and caching of pre-trained models, conversion from TensorFlow checkpoints, and various file system operations for managing model assets.

Capabilities

File Caching and Download

Utilities for automatically downloading, caching, and managing pre-trained model files with support for URLs, local paths, and cloud storage.

def cached_path(url_or_filename, cache_dir=None):
    """
    Resolve a URL or local file path to a file on the local disk.

    If `url_or_filename` is a URL, the file is downloaded to a local cache
    directory and the path to the cached copy is returned. If it is a local
    file path, that path is returned as-is.

    Args:
        url_or_filename (str): URL to download or local file path
        cache_dir (str, optional): Directory to cache files. If None, uses the
            default cache directory (presumably PYTORCH_PRETRAINED_BERT_CACHE,
            which this page describes as the default cache directory --
            TODO confirm against the implementation)

    Returns:
        str: Path to the cached or local file

    Raises:
        EnvironmentError: If file cannot be found or downloaded
    """

Constants

Standard filenames and cache directory configuration for model management.

PYTORCH_PRETRAINED_BERT_CACHE = "~/.pytorch_pretrained_bert"

Default cache directory for storing downloaded model files and checkpoints.

CONFIG_NAME = "config.json"

Standard filename for model configuration files.

WEIGHTS_NAME = "pytorch_model.bin"

Standard filename for PyTorch model weight files.

TensorFlow Weight Conversion

Functions to convert TensorFlow checkpoints to PyTorch format for all supported model architectures.

BERT Weight Conversion

def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """
    Load TensorFlow BERT checkpoint weights into PyTorch BERT model.

    The usage example on this page ignores the return value and goes on to
    save `bert_model.state_dict()`, i.e. callers keep using the `model`
    object they passed in.

    Args:
        model: PyTorch BERT model instance (any BERT variant)
        tf_checkpoint_path (str): Path to TensorFlow checkpoint file

    Returns:
        PyTorch model with loaded TensorFlow weights
        (NOTE(review): presumably the same `model` instance -- confirm
        against the implementation)

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """

OpenAI GPT Weight Conversion

def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """
    Load TensorFlow OpenAI GPT checkpoint into PyTorch model.

    Unlike the BERT loader, this takes a folder path rather than a single
    checkpoint file. The usage example on this page ignores the return value
    and goes on to save `gpt_model.state_dict()`, i.e. callers keep using the
    `model` object they passed in.

    Args:
        model: PyTorch OpenAI GPT model instance
        openai_checkpoint_folder_path (str): Path to folder containing TF checkpoint files

    Returns:
        PyTorch model with loaded TensorFlow weights
        (NOTE(review): presumably the same `model` instance -- confirm
        against the implementation)

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint files don't exist
    """

GPT-2 Weight Conversion

def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """
    Load TensorFlow GPT-2 checkpoint into PyTorch model.

    NOTE(review): the BERT and OpenAI GPT usage examples on this page keep
    using the passed-in model after the call and ignore the return value;
    presumably the same applies here -- confirm against the implementation.

    Args:
        model: PyTorch GPT-2 model instance
        gpt2_checkpoint_path (str): Path to TensorFlow GPT-2 checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """

Transformer-XL Weight Conversion

def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """
    Load TensorFlow Transformer-XL checkpoint into PyTorch model.

    This is the only loader on this page that also takes a `config` argument.

    NOTE(review): the BERT and OpenAI GPT usage examples on this page keep
    using the passed-in model after the call and ignore the return value;
    presumably the same applies here -- confirm against the implementation.

    Args:
        model: PyTorch Transformer-XL model instance
        config: TransfoXLConfig instance (presumably the configuration that
            `model` was built from -- TODO confirm)
        tf_path (str): Path to TensorFlow checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """

Usage Examples

Basic File Caching

from pytorch_pretrained_bert import cached_path

# Download and cache from URL (no cache_dir given, so the default cache
# directory is used)
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin"
local_path = cached_path(model_url)
print(f"Model cached at: {local_path}")

# Use local file path (returns as-is).
# The path below is a placeholder: replace it with a file that exists, since
# cached_path is documented to raise EnvironmentError when the file cannot
# be found.
local_file = "/path/to/local/model.bin"
path = cached_path(local_file)
print(f"Local file: {path}")

# Use custom cache directory: same URL as above, but the download is cached
# under ./my_models/ instead of the default location
custom_cache = cached_path(model_url, cache_dir="./my_models/")
print(f"Cached in custom directory: {custom_cache}")

Converting TensorFlow Models

import torch  # required for torch.save below; the original snippet omitted it

from pytorch_pretrained_bert import (
    BertModel, BertConfig, load_tf_weights_in_bert,
    OpenAIGPTModel, OpenAIGPTConfig, load_tf_weights_in_openai_gpt
)

# Convert BERT from TensorFlow:
# build a PyTorch model from the JSON config, then load the TensorFlow
# checkpoint weights into that model object.
bert_config = BertConfig.from_json_file("bert_config.json")
bert_model = BertModel(bert_config)
load_tf_weights_in_bert(bert_model, "bert_model.ckpt")

# Save converted model (only the weights, as a state dict)
torch.save(bert_model.state_dict(), "pytorch_bert_model.bin")

# Convert OpenAI GPT from TensorFlow:
# same steps, but the loader takes the checkpoint *folder* rather than a
# single checkpoint file.
gpt_config = OpenAIGPTConfig.from_json_file("openai_gpt_config.json")
gpt_model = OpenAIGPTModel(gpt_config)
load_tf_weights_in_openai_gpt(gpt_model, "./openai_gpt_checkpoint/")

# Save converted model (only the weights, as a state dict)
torch.save(gpt_model.state_dict(), "pytorch_openai_gpt.bin")

Custom Cache Management

from pytorch_pretrained_bert import (
    PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
)
import os

# Resolve the default cache directory (expands a leading "~" to the user's
# home directory)
cache_dir = os.path.expanduser(PYTORCH_PRETRAINED_BERT_CACHE)
print(f"Default cache: {cache_dir}")

# Build the paths of the standard weights/config filenames inside the cache
# directory. This only constructs and prints the paths; it does not check
# whether the files exist (use os.path.exists for that).
model_weights = os.path.join(cache_dir, WEIGHTS_NAME)
model_config = os.path.join(cache_dir, CONFIG_NAME)
print(f"Expected weights: {model_weights}")
print(f"Expected config: {model_config}")

Error Handling

from pytorch_pretrained_bert import cached_path
import os

def safe_download(url_or_path, cache_dir=None):
    """Resolve a URL or local path via ``cached_path`` without ever raising.

    Args:
        url_or_path: URL or local file path, forwarded to ``cached_path``.
        cache_dir: Optional cache directory, forwarded to ``cached_path``.

    Returns:
        The resolved local path when it exists on disk; ``None`` when the
        resolved path is missing or any exception occurs (a message is
        printed in every case).
    """
    try:
        resolved = cached_path(url_or_path, cache_dir=cache_dir)

        # Guard clause: bail out early when the resolved path is not on disk.
        if not os.path.exists(resolved):
            print(f"File not found: {resolved}")
            return None

        n_bytes = os.path.getsize(resolved)
        print(f"Successfully accessed: {resolved} ({n_bytes} bytes)")
        return resolved
    except EnvironmentError as err:
        # The documented failure mode of cached_path (file cannot be found
        # or downloaded).
        print(f"Error accessing {url_or_path}: {err}")
        return None
    except Exception as err:
        # Deliberate catch-all: this helper is best-effort and must not raise.
        print(f"Unexpected error: {err}")
        return None

# Test with various inputs.
# safe_download returns the resolved path on success and None on any failure,
# printing a message either way, so this loop never raises.
test_files = [
    "https://invalid-url.com/model.bin",  # Invalid URL
    "/nonexistent/path/model.bin",        # Nonexistent local path
    "https://github.com/",                # Valid URL, but not a model file
]

for test_file in test_files:
    result = safe_download(test_file)
    print(f"Result for {test_file}: {result}\n")

Install with Tessl CLI

npx tessl i tessl/pypi-pytorch-pretrained-bert

docs

bert-models.md

gpt-models.md

index.md

optimizers.md

tokenizers.md

utilities.md

tile.json