or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

data.md · features.md · index.md · layers.md · models.md · training.md · utils.md

docs/utils.md

0

# Utilities and Helpers

1

2

General utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.

3

4

## Capabilities

5

6

### Model Utilities

7

8

Functions for model management, parameter manipulation, and model state operations.

9

10

```python { .api }

11

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:

12

"""

13

Unwrap model from DDP/EMA/other wrappers.

14

15

Args:

16

model: Wrapped model instance

17

18

Returns:

19

Unwrapped base model

20

"""

21

22

def get_state_dict(

23

model: torch.nn.Module,

24

unwrap_fn: Callable = unwrap_model

25

) -> Dict[str, Any]:

26

"""

27

Get model state dictionary with unwrapping.

28

29

Args:

30

model: Model to get state dict from

31

unwrap_fn: Function to unwrap model

32

33

Returns:

34

Model state dictionary

35

"""

36

37

def freeze(model: torch.nn.Module) -> None:

38

"""

39

Freeze all model parameters (disable gradients).

40

41

Args:

42

model: Model to freeze

43

"""

44

45

def unfreeze(model: torch.nn.Module) -> None:

46

"""

47

Unfreeze all model parameters (enable gradients).

48

49

Args:

50

model: Model to unfreeze

51

"""

52

53

def reparameterize_model(

54

model: torch.nn.Module,

55

**kwargs

56

) -> torch.nn.Module:

57

"""

58

Reparameterize model for inference optimization.

59

60

Args:

61

model: Model to reparameterize

62

**kwargs: Reparameterization options

63

64

Returns:

65

Reparameterized model

66

"""

67

```

68

69

### Distributed Training Utilities

70

71

Functions for initializing and managing distributed training across multiple devices and nodes.

72

73

```python { .api }

74

def init_distributed_device(args) -> Tuple[torch.device, int]:

75

"""

76

Initialize the distributed training device and determine the world size.

77

78

Args:

79

args: Arguments namespace with distributed training configuration

80

81

Returns:

82

Tuple of (device, world_size) for distributed training setup

83

"""

84

85

def distribute_bn(

86

model: torch.nn.Module,

87

world_size: int,

88

reduce: bool = False

89

) -> None:

90

"""

91

Distribute batch normalization statistics across processes.

92

93

Args:

94

model: Model with batch norm layers

95

world_size: Number of distributed processes

96

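reduce: If True, average the batch norm running statistics across all processes; if False, broadcast the statistics from rank 0 to every process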

97

"""

98

99

def reduce_tensor(

100

tensor: torch.Tensor,

101

world_size: int = 1

102

) -> torch.Tensor:

103

"""

104

Reduce tensor across distributed processes.

105

106

Args:

107

tensor: Tensor to reduce

108

world_size: Number of processes

109

110

Returns:

111

Reduced tensor

112

"""

113

114

def world_info_from_env() -> Tuple[int, int, int]:

115

"""

116

Get distributed world info from environment variables.

117

118

Returns:

119

Tuple of (local_rank, world_rank, world_size)

120

"""

121

122

def is_distributed_env() -> bool:

123

"""

124

Check if running in distributed environment.

125

126

Returns:

127

True if distributed environment detected

128

"""

129

```

130

131

### Mixed Precision Training

132

133

Utilities for managing mixed precision training with automatic mixed precision (AMP).

134

135

```python { .api }

136

class ApexScaler:

137

"""

138

Gradient scaler using NVIDIA Apex.

139

140

Args:

141

loss_scale: Loss scaling mode; 'dynamic' (the default) adjusts the scale automatically during training

142

init_scale: Initial scale value

143

scale_factor: Scale adjustment factor

144

scale_window: Scale adjustment window

145

"""

146

147

def __init__(

148

self,

149

loss_scale: str = 'dynamic',

150

init_scale: float = 2.**16,

151

scale_factor: float = 2.0,

152

scale_window: int = 2000

153

): ...

154

155

def scale_loss(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer): ...

156

def unscale_grads(self, optimizer: torch.optim.Optimizer): ...

157

def update_scale(self, overflow: bool): ...

158

159

class NativeScaler:

160

"""

161

Native PyTorch gradient scaler for mixed precision.

162

163

Args:

164

enabled: Enable gradient scaling

165

init_scale: Initial scaling factor

166

growth_factor: Scale growth factor

167

backoff_factor: Scale backoff factor

168

growth_interval: Interval for scale growth

169

"""

170

171

def __init__(

172

self,

173

enabled: bool = True,

174

init_scale: float = 2.**16,

175

growth_factor: float = 2.0,

176

backoff_factor: float = 0.5,

177

growth_interval: int = 2000

178

): ...

179

180

def scale(self, loss: torch.Tensor) -> torch.Tensor: ...

181

def step(self, optimizer: torch.optim.Optimizer) -> None: ...

182

def update(self) -> None: ...

183

```

184

185

### CUDA and Performance Utilities

186

187

Functions for managing CUDA operations, JIT compilation, and performance optimization.

188

189

```python { .api }

190

def set_jit_legacy(enable: bool) -> None:

191

"""

192

Set legacy JIT mode.

193

194

Args:

195

enable: Enable legacy JIT mode

196

"""

197

198

def set_jit_fuser(fuser_name: str) -> None:

199

"""

200

Set JIT fuser type.

201

202

Args:

203

fuser_name: Name of fuser ('te', 'old', 'nvfuser')

204

"""

205

206

def random_seed(seed: int, rank: int = 0) -> None:

207

"""

208

Set random seed for reproducibility across all libraries.

209

210

Args:

211

seed: Random seed value

212

rank: Process rank in distributed training; it is added to the seed so that each process uses a distinct seed

213

"""

214

```

215

216

### Logging and Configuration

217

218

Utilities for setting up logging, argument parsing, and experiment configuration.

219

220

```python { .api }

221

def setup_default_logging(

222

default_level: int = logging.INFO,

223

log_path: str = '',

224

**kwargs

225

) -> None:

226

"""

227

Setup default logging configuration.

228

229

Args:

230

default_level: Default logging level

231

log_path: Path for log file

232

**kwargs: Additional logging configuration

233

"""

234

235

def natural_key(string_: str) -> List[Union[int, str]]:

236

"""

237

Natural sorting key function for strings with numbers.

238

239

Args:

240

string_: String to create key for

241

242

Returns:

243

List of components for natural sorting

244

"""

245

246

def add_bool_arg(

247

parser,

248

name: str,

249

default: bool = False,

250

help: str = ''

251

) -> None:

252

"""

253

Add boolean argument to argument parser with --name/--no-name pattern.

254

255

Args:

256

parser: ArgumentParser instance

257

name: Argument name

258

default: Default value

259

help: Help text

260

"""

261

```

262

263

### Training Summary and Output

264

265

Functions for managing training outputs, experiment directories, and result summaries.

266

267

```python { .api }

268

def update_summary(

269

epoch: int,

270

train_metrics: Dict[str, float],

271

eval_metrics: Dict[str, float],

272

filename: str,

273

lr: float = None,

274

write_header: bool = False,

275

log_wandb: bool = False

276

) -> None:

277

"""

278

Update training summary with metrics.

279

280

Args:

281

epoch: Current epoch

282

train_metrics: Training metrics dictionary

283

eval_metrics: Evaluation metrics dictionary

284

filename: Summary file path

285

lr: Current learning rate

286

write_header: Write CSV header

287

log_wandb: Log to Weights & Biases

288

"""

289

290

def get_outdir(path: str, *paths: str, inc: bool = False) -> str:

291

"""

292

Get output directory for experiments.

293

294

Args:

295

path: Base output path

296

*paths: Additional path components

297

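inc: If True and the directory already exists, append an incrementing numeric suffix to create a new directory instead of reusing the existing one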

298

299

Returns:

300

Output directory path

301

"""

302

```

303

304

## Training Monitoring Classes

305

306

### Metrics Tracking

307

308

```python { .api }

309

class AverageMeter:

310

"""

311

Computes and stores the average and current value for metrics tracking.

312

313

Args:

314

name: Name of the metric

315

fmt: Format string for display

316

"""

317

318

def __init__(self, name: str = '', fmt: str = ':f'): ...

319

320

def reset(self) -> None:

321

"""Reset all statistics to initial values."""

322

323

def update(self, val: float, n: int = 1) -> None:

324

"""

325

Update meter with new value.

326

327

Args:

328

val: New value to add

329

n: Number of samples the value represents

330

"""

331

332

def __str__(self) -> str:

333

"""String representation of current meter state."""

334

335

def accuracy(

336

output: torch.Tensor,

337

target: torch.Tensor,

338

topk: Tuple[int, ...] = (1,)

339

) -> List[torch.Tensor]:

340

"""

341

Compute accuracy for specified top-k values.

342

343

Args:

344

output: Model output predictions [batch_size, num_classes]

345

target: Ground truth labels [batch_size]

346

topk: Tuple of k values for top-k accuracy

347

348

Returns:

349

List of accuracy tensors for each k value

350

"""

351

```

352

353

### Model EMA Management

354

355

```python { .api }

356

class ModelEma:

357

"""

358

Model Exponential Moving Average for maintaining shadow weights.

359

360

Args:

361

model: Model to track with EMA

362

decay: EMA decay rate (default: 0.9999)

363

device: Device to store EMA parameters

364

resume: Path to resume EMA from checkpoint

365

"""

366

367

def __init__(

368

self,

369

model: torch.nn.Module,

370

decay: float = 0.9999,

371

device: torch.device = None,

372

resume: str = ''

373

): ...

374

375

def update(self, model: torch.nn.Module) -> None:

376

"""

377

Update EMA parameters from model.

378

379

Args:

380

model: Source model for updates

381

"""

382

383

def set(self, model: torch.nn.Module) -> None:

384

"""

385

Set EMA parameters from model (copy all parameters).

386

387

Args:

388

model: Source model to copy from

389

"""

390

391

class ModelEmaV2:

392

"""

393

Model EMA v2 with improved decay adjustment based on training progress.

394

395

Args:

396

model: Model to track

397

decay: Base decay rate

398

decay_type: Type of decay adjustment ('exponential', 'linear')

399

device: Device for EMA parameters

400

"""

401

402

def __init__(

403

self,

404

model: torch.nn.Module,

405

decay: float = 0.9999,

406

decay_type: str = 'exponential',

407

device: torch.device = None

408

): ...

409

410

class ModelEmaV3:

411

"""

412

Model EMA v3 with performance optimizations and memory efficiency.

413

414

Args:

415

model: Model to track

416

decay: EMA decay rate

417

update_after_step: Steps before starting EMA updates

418

use_ema_warmup: Use warmup for EMA updates

419

inv_gamma: Inverse gamma for warmup

420

power: Power for warmup

421

min_value: Minimum decay value

422

device: Device for parameters

423

"""

424

425

def __init__(

426

self,

427

model: torch.nn.Module,

428

decay: float = 0.9999,

429

update_after_step: int = 100,

430

use_ema_warmup: bool = False,

431

inv_gamma: float = 1.0,

432

power: float = 2/3,

433

min_value: float = 0.0,

434

device: torch.device = None

435

): ...

436

```

437

438

### Checkpoint Management

439

440

```python { .api }

441

class CheckpointSaver:

442

"""

443

Saves model checkpoints with configurable retention and recovery policies.

444

445

Args:

446

model: Model to save

447

optimizer: Optimizer state to save

448

args: Training arguments/configuration

449

model_ema: EMA model to save

450

amp_scaler: Mixed precision scaler

451

checkpoint_prefix: Prefix for checkpoint filenames

452

recovery_prefix: Prefix for recovery checkpoints

453

checkpoint_dir: Directory for regular checkpoints

454

recovery_dir: Directory for recovery checkpoints

455

decreasing: Whether monitored metric is decreasing (lower is better)

456

max_history: Maximum number of checkpoints to keep

457

unwrap_fn: Function to unwrap model before saving

458

"""

459

460

def __init__(

461

self,

462

model: torch.nn.Module,

463

optimizer: torch.optim.Optimizer,

464

args = None,

465

model_ema: ModelEma = None,

466

amp_scaler = None,

467

checkpoint_prefix: str = 'checkpoint',

468

recovery_prefix: str = 'recovery',

469

checkpoint_dir: str = '',

470

recovery_dir: str = '',

471

decreasing: bool = False,

472

max_history: int = 10,

473

unwrap_fn: Callable = unwrap_model

474

): ...

475

476

def save_checkpoint(

477

self,

478

epoch: int,

479

metric: float = None

480

) -> Tuple[str, bool]:

481

"""

482

Save checkpoint if metric improved.

483

484

Args:

485

epoch: Current epoch number

486

metric: Metric value for comparison

487

488

Returns:

489

Tuple of (checkpoint_path, is_best)

490

"""

491

492

def save_recovery(

493

self,

494

epoch: int,

495

batch_idx: int = 0

496

) -> str:

497

"""

498

Save recovery checkpoint for resuming interrupted training.

499

500

Args:

501

epoch: Current epoch

502

batch_idx: Current batch index

503

504

Returns:

505

Path to saved recovery checkpoint

506

"""

507

```

508

509

## Usage Examples

510

511

### Basic Training Setup with Utilities

512

513

```python

514

import logging

515

import timm

516

from timm.utils import (

517

setup_default_logging, random_seed, ModelEma,

518

CheckpointSaver, AverageMeter, accuracy

519

)

520

521

# Setup logging

522

setup_default_logging(log_path='training.log')

523

logger = logging.getLogger(__name__)

524

525

# Set random seed for reproducibility

526

random_seed(42, rank=0)

527

528

# Create model and training components

529

model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

530

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

531

532

# Setup EMA tracking

533

model_ema = ModelEma(model, decay=0.9999)

534

535

# Setup checkpoint saving

536

saver = CheckpointSaver(

537

model=model,

538

optimizer=optimizer,

539

model_ema=model_ema,

540

checkpoint_dir='./checkpoints',

541

max_history=5,

542

decreasing=False # Higher accuracy is better

543

)

544

545

# Setup metrics tracking

546

losses = AverageMeter('Loss', ':.4e')

547

top1 = AverageMeter('Acc@1', ':6.2f')

548

top5 = AverageMeter('Acc@5', ':6.2f')

549

```

550

551

### Distributed Training Setup

552

553

```python

554

from timm.utils import (

555

init_distributed_device, distribute_bn, reduce_tensor,

556

is_distributed_env

557

)

558

559

# Initialize distributed training

560

device, world_size = init_distributed_device(args)

561

model = model.to(device)

562

563

if is_distributed_env():

564

# Synchronize batch norm statistics

565

distribute_bn(model, world_size, reduce=True)

566

567

# Wrap model for distributed training

568

model = torch.nn.parallel.DistributedDataParallel(

569

model, device_ids=[device], find_unused_parameters=False

570

)

571

572

# In training loop - reduce metrics across processes

573

def train_epoch(model, loader, optimizer, device, world_size):

574

losses = AverageMeter('Loss')

575

576

for batch_idx, (input, target) in enumerate(loader):

577

input, target = input.to(device), target.to(device)

578

579

output = model(input)

580

loss = criterion(output, target)

581

582

# Backward and optimization

583

optimizer.zero_grad()

584

loss.backward()

585

optimizer.step()

586

587

# Reduce loss across processes

588

if world_size > 1:

589

loss = reduce_tensor(loss, world_size)

590

591

losses.update(loss.item(), input.size(0))

592

593

return losses.avg

594

```

595

596

### Mixed Precision Training

597

598

```python

599

from timm.utils import NativeScaler

600

601

# Setup mixed precision training

602

scaler = NativeScaler()

603

model = model.to(device)

604

605

def train_step(model, input, target, optimizer, scaler):

606

optimizer.zero_grad()

607

608

# Forward pass with autocast

609

with torch.cuda.amp.autocast():

610

output = model(input)

611

loss = criterion(output, target)

612

613

# Backward pass with gradient scaling

614

scaler.scale(loss).backward()

615

scaler.step(optimizer)

616

scaler.update()

617

618

return loss.item()

619

```

620

621

### Complete Training Loop with Utilities

622

623

```python

624

def train_model():

625

setup_default_logging()

626

random_seed(42)

627

628

# Model setup

629

model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

630

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

631

632

# Training utilities

633

model_ema = ModelEmaV2(model, decay=0.9999)

634

scaler = NativeScaler()

635

saver = CheckpointSaver(

636

model, optimizer, model_ema=model_ema, amp_scaler=scaler,

637

checkpoint_dir='./checkpoints'

638

)

639

640

# Metrics

641

train_losses = AverageMeter('Train Loss')

642

train_acc1 = AverageMeter('Train Acc@1')

643

644

for epoch in range(num_epochs):

645

# Training

646

model.train()

647

train_losses.reset()

648

train_acc1.reset()

649

650

for batch_idx, (input, target) in enumerate(train_loader):

651

input, target = input.to(device), target.to(device)

652

653

# Mixed precision forward pass

654

with torch.cuda.amp.autocast():

655

output = model(input)

656

loss = criterion(output, target)

657

658

# Backward pass

659

optimizer.zero_grad()

660

scaler.scale(loss).backward()

661

scaler.step(optimizer)

662

scaler.update()

663

664

# Update EMA

665

model_ema.update(model)

666

667

# Metrics

668

acc1, acc5 = accuracy(output, target, topk=(1, 5))

669

train_losses.update(loss.item(), input.size(0))

670

train_acc1.update(acc1.item(), input.size(0))

671

672

# Validation and checkpointing

673

val_acc = validate(model_ema.module, val_loader)

674

saver.save_checkpoint(epoch, val_acc)

675

676

logger.info(f'Epoch {epoch}: Train Loss {train_losses.avg:.4f}, '

677

f'Train Acc {train_acc1.avg:.2f}%, Val Acc {val_acc:.2f}%')

678

```

679

680

## Types

681

682

```python { .api }

683

from typing import Optional, Union, List, Dict, Callable, Any, Tuple

684

import torch

685

import logging

686

687

# Device and distributed types

688

DeviceType = torch.device

689

WorldInfo = Tuple[int, int, int] # (local_rank, world_rank, world_size)

690

691

# Metrics types

692

MetricValue = Union[float, torch.Tensor]

693

MetricDict = Dict[str, MetricValue]

694

695

# Checkpoint types

696

CheckpointDict = Dict[str, Any]

697

UnwrapFunction = Callable[[torch.nn.Module], torch.nn.Module]

698

699

# Scaler types

700

LossScaler = Union[torch.cuda.amp.GradScaler, Any]

701

702

# Logging types

703

LogLevel = int

704

Logger = logging.Logger

705

706

# Utility function types

707

SeedFunction = Callable[[int, int], None]

708

ReduceFunction = Callable[[torch.Tensor, int], torch.Tensor]

709

```