
# Precision Control and Optimization

Precision plugins provide mixed precision training, quantization, and various floating-point formats to optimize memory usage and training speed while maintaining model quality.

## Capabilities

### Mixed Precision Training

Automatic mixed precision training using 16-bit floats for the forward pass and 32-bit for loss computation.

```python { .api }
class MixedPrecision:
    def __init__(self, precision: str = "16-mixed", device: str = "cuda"):
        """
        Initialize mixed precision plugin.

        Args:
            precision: Precision mode ('16-mixed', 'bf16-mixed')
            device: Target device
        """
```
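A key reason mixed precision pairs 16-bit compute with 32-bit bookkeeping is gradient underflow: values below float16's representable range silently become zero. The sketch below illustrates this and the loss-scaling remedy using only the standard library's binary16 round-trip; it is a numerical illustration, not the plugin's actual implementation.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 binary16 (what fp16 storage does)."""
    return struct.unpack("e", struct.pack("e", x))[0]

# A tiny gradient underflows to zero in fp16...
grad = 1e-8
print(to_fp16(grad))  # flushed to 0.0: the update is lost

# ...but scaling the loss (and hence the gradients) first keeps it representable.
scale = 65536.0
scaled = to_fp16(grad * scale)  # nonzero in fp16
recovered = scaled / scale      # unscale in 32/64-bit before the optimizer step
print(scaled, recovered)
```

This is why modes like `'16-mixed'` apply loss scaling automatically, while `'bf16-mixed'` usually does not need it (bfloat16 keeps float32's exponent range at the cost of mantissa bits).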

### Half Precision

16-bit floating point training for memory efficiency.

```python { .api }
class HalfPrecision:
    def __init__(self):
        """Initialize half precision plugin."""
```
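The memory saving is easy to estimate: weights (and activations stored in the same dtype) take half the bytes. A back-of-envelope sketch, using a hypothetical 7B-parameter model for illustration:

```python
# Back-of-envelope memory for model weights (hypothetical 7B-parameter model).
n_params = 7_000_000_000
bytes_fp32 = n_params * 4  # float32: 4 bytes per parameter
bytes_fp16 = n_params * 2  # float16: 2 bytes per parameter

print(f"fp32 weights: {bytes_fp32 / 2**30:.1f} GiB")
print(f"fp16 weights: {bytes_fp16 / 2**30:.1f} GiB")
```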

### Double Precision

64-bit floating point training for maximum numerical precision.

```python { .api }
class DoublePrecision:
    def __init__(self):
        """Initialize double precision plugin."""
```
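Double precision matters when long chains of additions accumulate rounding error. The sketch below simulates 32-bit accumulation by rounding every intermediate result through IEEE 754 binary32 and compares it with native 64-bit accumulation; it demonstrates the numerical motivation, not the plugin itself.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (binary64) to the nearest IEEE 754 binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# Accumulate 0.1 one hundred thousand times.
acc64 = 0.0
acc32 = 0.0
for _ in range(100_000):
    acc64 += 0.1                            # 64-bit accumulation
    acc32 = to_fp32(acc32 + to_fp32(0.1))   # every intermediate rounded to 32-bit

print(abs(acc64 - 10_000.0))  # tiny residual error
print(abs(acc32 - 10_000.0))  # orders of magnitude larger
```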

### Quantization

8-bit and 4-bit quantization using BitsAndBytes for memory-efficient training of large models.

```python { .api }
class BitsandbytesPrecision:
    def __init__(
        self,
        mode: str = "int8",
        dtype: Optional[torch.dtype] = None,
        ignore_modules: Optional[Set[str]] = None
    ):
        """
        Initialize BitsAndBytes precision plugin.

        Args:
            mode: Quantization mode ('int8', 'int4', 'nf4', 'fp4')
            dtype: Data type for computation
            ignore_modules: Modules to skip during quantization
        """
```
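To give intuition for what `'int8'` mode does, here is a minimal absmax quantization sketch in pure Python: each weight is mapped onto the integer range [-127, 127] with a single scale factor. This only illustrates the core idea; bitsandbytes itself uses more sophisticated block-wise scaling and outlier handling.

```python
def quantize_absmax_int8(weights):
    """Absmax int8 quantization: map [-max|w|, max|w|] onto [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127
    q = [round(w / scale) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float weights from int8 codes."""
    return [qi * scale for qi in q]

weights = [0.5, -1.2, 0.03, 2.54]
q, scale = quantize_absmax_int8(weights)
restored = dequantize(q, scale)

print(q)      # int8-range codes
print(scale)  # one float scale stored per tensor (per block in practice)
print(max(abs(w - r) for w, r in zip(weights, restored)))  # bounded by scale / 2
```

Storing one byte per weight plus a scale is what cuts memory roughly 4x versus float32, at the cost of the quantization error shown above.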

### DeepSpeed Precision

Precision plugin for DeepSpeed optimization with ZeRO memory optimization.

```python { .api }
class DeepSpeedPrecision:
    def __init__(self):
        """Initialize DeepSpeed precision plugin."""
```

### FSDP Precision

Precision plugin optimized for Fully Sharded Data Parallel training.

```python { .api }
class FSDPPrecision:
    def __init__(self):
        """Initialize FSDP precision plugin."""
```

### Transformer Engine Precision

NVIDIA Transformer Engine precision for optimized transformer training.

```python { .api }
class TransformerEnginePrecision:
    def __init__(
        self,
        weights_dtype: torch.dtype = torch.float32,
        recipe: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Transformer Engine precision plugin.

        Args:
            weights_dtype: Data type for model weights
            recipe: Transformer Engine recipe configuration
        """
```

### XLA Precision

Precision plugin for TPU training with XLA compilation.

```python { .api }
class XLAPrecision:
    def __init__(self):
        """Initialize XLA precision plugin for TPU training."""
```

### Base Precision

Base class for implementing custom precision plugins.

```python { .api }
class Precision:
    def __init__(self):
        """Initialize base precision plugin."""

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert module for precision."""

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert optimizer for precision."""

    def backward(self, tensor: Tensor, model: nn.Module) -> None:
        """Perform backward pass with precision handling."""
```
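A custom plugin implements these three hooks. The toy subclass below follows the same method names but uses a plain dict as a stand-in for `nn.Module` so it runs without torch, simulating half-precision storage by rounding parameters through IEEE 754 binary16. It is purely illustrative, not how the real plugins are written.

```python
import struct

def _round_fp16(x: float) -> float:
    """Simulate fp16 storage by round-tripping through binary16."""
    return struct.unpack("e", struct.pack("e", x))[0]

class ToyHalfPrecision:
    """Follows the Precision interface above with plain-Python stand-ins."""

    def convert_module(self, module: dict) -> dict:
        # Stand-in for casting a module's parameters to fp16.
        module["params"] = [_round_fp16(p) for p in module["params"]]
        return module

    def convert_optimizer(self, optimizer):
        # Nothing to change for this toy plugin.
        return optimizer

    def backward(self, tensor, model) -> None:
        # A real plugin would trigger the backward pass here,
        # possibly scaling the loss first (see MixedPrecision).
        pass

plugin = ToyHalfPrecision()
module = plugin.convert_module({"params": [0.1, 1e-8]})
print(module["params"])  # 0.1 rounded to the nearest fp16 value; 1e-8 underflows to 0.0
```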