# Utilities

File handling, caching, and model loading utilities providing automatic download and caching of pre-trained models, conversion from TensorFlow checkpoints, and various file system operations for managing model assets.

## Capabilities

### File Caching and Download

Utilities for automatically downloading, caching, and managing pre-trained model files with support for URLs, local paths, and cloud storage.
```python { .api }
def cached_path(url_or_filename, cache_dir=None):
    """
    Download and cache files from URLs or return local file paths.

    Given a URL or local file path, this function downloads the file (if it's a URL)
    to a local cache directory and returns the path to the cached file.

    Args:
        url_or_filename (str): URL to download or local file path
        cache_dir (str, optional): Directory to cache files. If None, uses the
            default cache directory

    Returns:
        str: Path to the cached or local file

    Raises:
        EnvironmentError: If the file cannot be found or downloaded
    """
```
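Internally, a URL-keyed cache like this needs a deterministic mapping from a URL to a local filename, so repeated calls for the same URL resolve to the same cached file. The sketch below illustrates that idea with a hash of the URL (and an optional ETag to distinguish versions); the exact scheme here is illustrative, not necessarily the library's implementation.

```python
import hashlib

def url_to_cache_filename(url, etag=None):
    # Derive a stable filename from the URL so the same URL always
    # maps to the same cached file. Including the server's ETag (when
    # available) lets a new remote version get a new cache entry.
    filename = hashlib.sha256(url.encode("utf-8")).hexdigest()
    if etag is not None:
        filename += "." + hashlib.sha256(etag.encode("utf-8")).hexdigest()
    return filename

url = "https://example.com/model.bin"
# Same input always yields the same cache filename.
print(url_to_cache_filename(url))
```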

### Constants

Standard filenames and cache directory configuration for model management.

```python { .api }
PYTORCH_PRETRAINED_BERT_CACHE = "~/.pytorch_pretrained_bert"
```

Default cache directory for storing downloaded model files and checkpoints.

```python { .api }
CONFIG_NAME = "config.json"
```

Standard filename for model configuration files.

```python { .api }
WEIGHTS_NAME = "pytorch_model.bin"
```

Standard filename for PyTorch model weight files.

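Together, these constants describe the expected layout of a saved model directory: one weights file and one config file under fixed names. A minimal sketch of assembling such a directory (the constants are redefined locally to keep the sketch self-contained, and the config contents are dummy values for illustration):

```python
import json
import os
import tempfile

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

def save_model_assets(output_dir, config_dict, weights_bytes):
    # Write the config and weights under the standard filenames so that
    # loaders expecting this layout can find them.
    os.makedirs(output_dir, exist_ok=True)
    config_path = os.path.join(output_dir, CONFIG_NAME)
    weights_path = os.path.join(output_dir, WEIGHTS_NAME)
    with open(config_path, "w") as f:
        json.dump(config_dict, f, indent=2)
    with open(weights_path, "wb") as f:
        f.write(weights_bytes)
    return config_path, weights_path

with tempfile.TemporaryDirectory() as tmp:
    cfg, wts = save_model_assets(tmp, {"hidden_size": 768}, b"\x00" * 4)
    print(os.path.basename(cfg), os.path.basename(wts))
```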
## TensorFlow Weight Conversion

Functions to convert TensorFlow checkpoints to PyTorch format for all supported model architectures.

### BERT Weight Conversion

```python { .api }
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """
    Load TensorFlow BERT checkpoint weights into a PyTorch BERT model.

    Args:
        model: PyTorch BERT model instance (any BERT variant)
        tf_checkpoint_path (str): Path to TensorFlow checkpoint file

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```
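Under the hood, converters like this walk the TensorFlow variable names (e.g. `bert/encoder/layer_0/attention/self/query/kernel`) and map each one onto the corresponding PyTorch parameter path, renaming `kernel` to `weight` (with a transpose of the array), `gamma` to `weight`, and `beta` to `bias`. A simplified, TensorFlow-free sketch of just the name-mapping step (the rules shown are a subset, for illustration):

```python
def tf_name_to_pytorch(tf_name):
    # Map a TensorFlow variable name to a dotted PyTorch parameter path.
    # Rules (subset): kernel -> weight, gamma -> weight, beta -> bias,
    # layer_N -> layer.N, and '/' separators become '.'.
    rename = {"kernel": "weight", "gamma": "weight", "beta": "bias"}
    parts = []
    for piece in tf_name.split("/"):
        if piece.startswith("layer_"):
            parts.extend(["layer", piece[len("layer_"):]])
        else:
            parts.append(rename.get(piece, piece))
    return ".".join(parts)

print(tf_name_to_pytorch("bert/encoder/layer_0/attention/self/query/kernel"))
# → bert.encoder.layer.0.attention.self.query.weight
```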

### OpenAI GPT Weight Conversion

```python { .api }
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """
    Load TensorFlow OpenAI GPT checkpoint into a PyTorch model.

    Args:
        model: PyTorch OpenAI GPT model instance
        openai_checkpoint_folder_path (str): Path to folder containing TF checkpoint files

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint files don't exist
    """
```

### GPT-2 Weight Conversion

```python { .api }
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """
    Load TensorFlow GPT-2 checkpoint into a PyTorch model.

    Args:
        model: PyTorch GPT-2 model instance
        gpt2_checkpoint_path (str): Path to TensorFlow GPT-2 checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```

### Transformer-XL Weight Conversion

```python { .api }
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """
    Load TensorFlow Transformer-XL checkpoint into a PyTorch model.

    Args:
        model: PyTorch Transformer-XL model instance
        config: TransfoXLConfig instance
        tf_path (str): Path to TensorFlow checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```

## Usage Examples

### Basic File Caching

```python
from pytorch_pretrained_bert import cached_path

# Download and cache from a URL
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin"
local_path = cached_path(model_url)
print(f"Model cached at: {local_path}")

# Use a local file path (returned as-is)
local_file = "/path/to/local/model.bin"
path = cached_path(local_file)
print(f"Local file: {path}")

# Use a custom cache directory
custom_cache = cached_path(model_url, cache_dir="./my_models/")
print(f"Cached in custom directory: {custom_cache}")
```

### Converting TensorFlow Models

```python
import torch
from pytorch_pretrained_bert import (
    BertModel, BertConfig, load_tf_weights_in_bert,
    OpenAIGPTModel, OpenAIGPTConfig, load_tf_weights_in_openai_gpt
)

# Convert BERT from TensorFlow
bert_config = BertConfig.from_json_file("bert_config.json")
bert_model = BertModel(bert_config)
load_tf_weights_in_bert(bert_model, "bert_model.ckpt")

# Save the converted model
torch.save(bert_model.state_dict(), "pytorch_bert_model.bin")

# Convert OpenAI GPT from TensorFlow
gpt_config = OpenAIGPTConfig.from_json_file("openai_gpt_config.json")
gpt_model = OpenAIGPTModel(gpt_config)
load_tf_weights_in_openai_gpt(gpt_model, "./openai_gpt_checkpoint/")

# Save the converted model
torch.save(gpt_model.state_dict(), "pytorch_openai_gpt.bin")
```

### Custom Cache Management

```python
from pytorch_pretrained_bert import (
    PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
)
import os

# Check the default cache directory
cache_dir = os.path.expanduser(PYTORCH_PRETRAINED_BERT_CACHE)
print(f"Default cache: {cache_dir}")

# Build paths to the standard model files
model_weights = os.path.join(cache_dir, WEIGHTS_NAME)
model_config = os.path.join(cache_dir, CONFIG_NAME)
print(f"Expected weights: {model_weights}")
print(f"Expected config: {model_config}")
```
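Once the cache starts accumulating checkpoints, it helps to see what is in it. A small helper that lists cached files and their sizes (a sketch; point it at whichever cache directory you use):

```python
import os

def list_cache(cache_dir):
    # Return (filename, size_in_bytes) for every regular file in cache_dir,
    # or an empty list if the directory does not exist yet.
    entries = []
    if os.path.isdir(cache_dir):
        for name in sorted(os.listdir(cache_dir)):
            path = os.path.join(cache_dir, name)
            if os.path.isfile(path):
                entries.append((name, os.path.getsize(path)))
    return entries

cache_dir = os.path.expanduser("~/.pytorch_pretrained_bert")
for name, size in list_cache(cache_dir):
    print(f"{name}: {size} bytes")
```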

### Error Handling

```python
from pytorch_pretrained_bert import cached_path
import os

def safe_download(url_or_path, cache_dir=None):
    """Safely download or access a file with error handling."""
    try:
        path = cached_path(url_or_path, cache_dir=cache_dir)
        if os.path.exists(path):
            size = os.path.getsize(path)
            print(f"Successfully accessed: {path} ({size} bytes)")
            return path
        else:
            print(f"File not found: {path}")
            return None
    except EnvironmentError as e:
        print(f"Error accessing {url_or_path}: {e}")
        return None
    except Exception as e:
        print(f"Unexpected error: {e}")
        return None

# Test with various inputs
test_files = [
    "https://invalid-url.com/model.bin",  # Invalid URL
    "/nonexistent/path/model.bin",        # Nonexistent local path
    "https://github.com/",                # Valid URL, invalid model
]

for test_file in test_files:
    result = safe_download(test_file)
    print(f"Result for {test_file}: {result}\n")
```