docs/utilities.md

# Utilities

File handling, caching, and model loading utilities providing automatic download and caching of pre-trained models, conversion from TensorFlow checkpoints, and various file system operations for managing model assets.

## Capabilities

### File Caching and Download

Utilities for automatically downloading, caching, and managing pre-trained model files with support for URLs, local paths, and cloud storage.

```python { .api }
def cached_path(url_or_filename, cache_dir=None):
    """
    Download and cache files from URLs or return local file paths.

    Given a URL or local file path, this function downloads the file (if it's a URL)
    to a local cache directory and returns the path to the cached file.

    Args:
        url_or_filename (str): URL to download or local file path
        cache_dir (str, optional): Directory to cache files. If None, uses default cache directory

    Returns:
        str: Path to the cached or local file

    Raises:
        EnvironmentError: If file cannot be found or downloaded
    """
```

### Constants

Standard filenames and cache directory configuration for model management.

```python { .api }
PYTORCH_PRETRAINED_BERT_CACHE = "~/.pytorch_pretrained_bert"
```

Default cache directory for storing downloaded model files and checkpoints.

```python { .api }
CONFIG_NAME = "config.json"
```

Standard filename for model configuration files.

```python { .api }
WEIGHTS_NAME = "pytorch_model.bin"
```

Standard filename for PyTorch model weight files.

## TensorFlow Weight Conversion

Functions to convert TensorFlow checkpoints to PyTorch format for all supported model architectures.

### BERT Weight Conversion

```python { .api }
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """
    Load TensorFlow BERT checkpoint weights into PyTorch BERT model.

    Args:
        model: PyTorch BERT model instance (any BERT variant)
        tf_checkpoint_path (str): Path to TensorFlow checkpoint file

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```

### OpenAI GPT Weight Conversion

```python { .api }
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """
    Load TensorFlow OpenAI GPT checkpoint into PyTorch model.

    Args:
        model: PyTorch OpenAI GPT model instance
        openai_checkpoint_folder_path (str): Path to folder containing TF checkpoint files

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint files don't exist
    """
```

### GPT-2 Weight Conversion

```python { .api }
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """
    Load TensorFlow GPT-2 checkpoint into PyTorch model.

    Args:
        model: PyTorch GPT-2 model instance
        gpt2_checkpoint_path (str): Path to TensorFlow GPT-2 checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```

### Transformer-XL Weight Conversion

```python { .api }
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """
    Load TensorFlow Transformer-XL checkpoint into PyTorch model.

    Args:
        model: PyTorch Transformer-XL model instance
        config: TransfoXLConfig instance
        tf_path (str): Path to TensorFlow checkpoint

    Returns:
        PyTorch model with loaded TensorFlow weights

    Raises:
        ValueError: If checkpoint format is incompatible
        FileNotFoundError: If checkpoint file doesn't exist
    """
```

## Usage Examples

### Basic File Caching

```python
from pytorch_pretrained_bert import cached_path

# Download and cache from URL
model_url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin"
local_path = cached_path(model_url)
print(f"Model cached at: {local_path}")

# Use local file path (returns as-is)
local_file = "/path/to/local/model.bin"
path = cached_path(local_file)
print(f"Local file: {path}")

# Use custom cache directory
custom_cache = cached_path(model_url, cache_dir="./my_models/")
print(f"Cached in custom directory: {custom_cache}")
```

### Converting TensorFlow Models

```python
import torch

from pytorch_pretrained_bert import (
    BertModel, BertConfig, load_tf_weights_in_bert,
    OpenAIGPTModel, OpenAIGPTConfig, load_tf_weights_in_openai_gpt
)

# Convert BERT from TensorFlow
bert_config = BertConfig.from_json_file("bert_config.json")
bert_model = BertModel(bert_config)
load_tf_weights_in_bert(bert_model, "bert_model.ckpt")

# Save converted model
torch.save(bert_model.state_dict(), "pytorch_bert_model.bin")

# Convert OpenAI GPT from TensorFlow
gpt_config = OpenAIGPTConfig.from_json_file("openai_gpt_config.json")
gpt_model = OpenAIGPTModel(gpt_config)
load_tf_weights_in_openai_gpt(gpt_model, "./openai_gpt_checkpoint/")

# Save converted model
torch.save(gpt_model.state_dict(), "pytorch_openai_gpt.bin")
```

### Custom Cache Management

```python
from pytorch_pretrained_bert import (
    PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
)
import os

# Check default cache directory
cache_dir = os.path.expanduser(PYTORCH_PRETRAINED_BERT_CACHE)
print(f"Default cache: {cache_dir}")

# Check for standard model files
model_weights = os.path.join(cache_dir, WEIGHTS_NAME)
model_config = os.path.join(cache_dir, CONFIG_NAME)
print(f"Expected weights: {model_weights}")
print(f"Expected config: {model_config}")
```

### Error Handling

```python
from pytorch_pretrained_bert import cached_path
import os

def safe_download(url_or_path, cache_dir=None):
    """Safely download or access file with error handling."""
    try:
        path = cached_path(url_or_path, cache_dir=cache_dir)
        if os.path.exists(path):
            size = os.path.getsize(path)
            print(f"Successfully accessed: {path} ({size} bytes)")
            return path
        else:
            print(f"File not found: {path}")
            return None
    except EnvironmentError as e:
        print(f"Error accessing {url_or_path}: {e}")
        return None
    except Exception as e:
        print(f"Unexpected error: {e}")
        return None

# Test with various inputs
test_files = [
    "https://invalid-url.com/model.bin",  # Invalid URL
    "/nonexistent/path/model.bin",  # Nonexistent local path
    "https://github.com/",  # Valid URL, invalid model
]

for test_file in test_files:
    result = safe_download(test_file)
    print(f"Result for {test_file}: {result}\n")
```