# Segmentation Metrics

Metrics for evaluating semantic and instance segmentation at the pixel level, including the Dice coefficient, Intersection over Union (IoU), and Hausdorff distance.

## Capabilities

### Dice Metrics

Dice similarity coefficient variants for segmentation evaluation.

```python { .api }
class DiceScore(Metric):
    def __init__(
        self,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        average: Optional[str] = "micro",
        mdmc_average: Optional[str] = None,
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        multiclass: Optional[bool] = None,
        **kwargs
    ): ...

class GeneralizedDiceScore(Metric):
    def __init__(
        self,
        num_classes: Optional[int] = None,
        include_background: bool = True,
        weight_type: str = "square",
        **kwargs
    ): ...
```
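
To illustrate what these classes compute, here is a plain-Python sketch on flattened masks that does not require torch. `dice_score` evaluates `2 * |P ∩ T| / (|P| + |T|)` for binary masks. `generalized_dice_score` weights each class by its target volume before pooling. Both function names are hypothetical and not part of torchmetrics. Several details are assumptions: empty masks score 1.0, a class absent from the target gets weight 0, and the `weight_type` values map to `1 / volume**2` (`"square"`), `1 / volume` (`"simple"`), and a constant 1 (`"uniform"`).

```python
from typing import Sequence


def dice_score(pred: Sequence[int], target: Sequence[int]) -> float:
    """Binary Dice on flattened 0/1 masks: 2 * |P ∩ T| / (|P| + |T|). Hypothetical helper."""
    intersection = sum(p & t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # Assumption: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


def generalized_dice_score(
    pred: Sequence[int],
    target: Sequence[int],
    num_classes: int,
    weight_type: str = "square",
) -> float:
    """Generalized Dice over flattened integer label maps. Hypothetical helper."""
    numerator = 0.0
    denominator = 0.0
    for c in range(num_classes):
        p_c = [int(p == c) for p in pred]
        t_c = [int(t == c) for t in target]
        volume = sum(t_c)  # number of target pixels of class c
        # Assumed weight mapping; a class absent from the target gets weight 0.
        if weight_type == "square":
            w = 1.0 / volume**2 if volume else 0.0
        elif weight_type == "simple":
            w = 1.0 / volume if volume else 0.0
        elif weight_type == "uniform":
            w = 1.0
        else:
            raise ValueError(f"unknown weight_type: {weight_type!r}")
        numerator += w * sum(a & b for a, b in zip(p_c, t_c))
        denominator += w * (sum(p_c) + sum(t_c))
    return 2.0 * numerator / denominator if denominator else 1.0


print(dice_score([1, 1, 0, 0], [1, 0, 0, 0]))  # 2 * 1 / (2 + 1) ≈ 0.667
print(generalized_dice_score([0, 1, 1, 2], [0, 1, 2, 2], 3, "uniform"))  # 0.75
```

Inverse-volume weighting prevents small structures from being swamped by large ones. This is the usual motivation for the generalized variant.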

### Intersection over Union

IoU-based metrics for segmentation quality assessment.

```python { .api }
class MeanIoU(Metric):
    def __init__(
        self,
        num_classes: int,
        ignore_index: Optional[int] = None,
        absent_score: float = 0.0,
        threshold: float = 0.5,
        **kwargs
    ): ...
```
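
Mean IoU averages the per-class Jaccard index `|P ∩ T| / |P ∪ T|`. The sketch below computes it in plain Python on flattened integer label maps. The helper names are hypothetical. The handling of `ignore_index` and `absent_score` is an assumption drawn from the parameter names, not a statement of how the library implements them. Pixels whose target equals `ignore_index` are dropped and that class is left out of the mean. `absent_score` is used for a class found in neither prediction nor target.

```python
from typing import List, Optional, Sequence


def per_class_iou(
    pred: Sequence[int],
    target: Sequence[int],
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
) -> List[float]:
    """Per-class IoU |P ∩ T| / |P ∪ T| on flattened label maps. Hypothetical helper."""
    # Assumption: pixels whose target label equals ignore_index are dropped.
    pairs = [(p, t) for p, t in zip(pred, target) if t != ignore_index]
    scores: List[float] = []
    for c in range(num_classes):
        if c == ignore_index:
            continue  # Assumption: the ignored class is also left out of the mean.
        inter = sum(1 for p, t in pairs if p == c and t == c)
        union = sum(1 for p, t in pairs if p == c or t == c)
        # Assumption: a class in neither pred nor target scores absent_score.
        scores.append(inter / union if union else absent_score)
    return scores


def mean_iou(
    pred: Sequence[int],
    target: Sequence[int],
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
) -> float:
    """Unweighted mean of the per-class IoU scores. Hypothetical helper."""
    scores = per_class_iou(pred, target, num_classes, ignore_index, absent_score)
    return sum(scores) / len(scores)


pred, target = [0, 1, 1, 2], [0, 1, 2, 2]
print(per_class_iou(pred, target, num_classes=3))  # [1.0, 0.5, 0.5]
print(mean_iou(pred, target, num_classes=4))  # class 3 absent -> scored 0.0, mean 0.5
```

For each class, Dice and IoU are related by `Dice = 2 * IoU / (1 + IoU)`. For example, class 1 above has IoU 0.5 and therefore Dice 2/3.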

### Distance-Based Metrics

Metrics based on spatial distances between segmentation boundaries.

```python { .api }
class HausdorffDistance(Metric):
    def __init__(
        self,
        percentile: Optional[float] = None,
        directed: bool = False,
        **kwargs
    ): ...
```
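
The Hausdorff distance between two masks is `max(h(P, T), h(T, P))`, where `h(A, B)` is the largest distance from a point of `A` to its nearest point of `B`. Below is a brute-force plain-Python sketch with hypothetical function names. It uses all foreground pixels instead of extracted boundaries, Euclidean distance, and a nearest-rank percentile. Two readings are assumptions, not the library's documented behavior. `percentile` replaces the outer max with that percentile, as in the common HD95 variant. `directed=True` returns only `h(P, T)`.

```python
import math
from typing import List, Optional, Sequence, Tuple

Point = Tuple[int, int]


def foreground_points(mask: Sequence[Sequence[int]]) -> List[Point]:
    """(row, col) coordinates of every nonzero pixel in a 2-D mask."""
    return [(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v]


def directed_hd(a: List[Point], b: List[Point], percentile: Optional[float] = None) -> float:
    """h(A, B): each point of A to its nearest point of B, then max or percentile."""
    nearest = sorted(min(math.dist(p, q) for q in b) for p in a)
    if percentile is None:
        return nearest[-1]
    # Assumption: nearest-rank percentile; the library may interpolate instead.
    rank = max(1, math.ceil(percentile / 100 * len(nearest)))
    return nearest[rank - 1]


def hausdorff_distance(
    pred: Sequence[Sequence[int]],
    target: Sequence[Sequence[int]],
    percentile: Optional[float] = None,
    directed: bool = False,
) -> float:
    """Brute-force Hausdorff distance between binary 2-D masks. Hypothetical helper."""
    p_pts, t_pts = foreground_points(pred), foreground_points(target)
    if not p_pts or not t_pts:
        return math.inf  # Assumption: an empty mask is reported as infinite distance.
    forward = directed_hd(p_pts, t_pts, percentile)
    if directed:
        return forward  # Assumption: directed means pred -> target only.
    return max(forward, directed_hd(t_pts, p_pts, percentile))


pred = [[1, 0, 0], [0, 0, 0], [0, 0, 0]]
target = [[1, 0, 0], [0, 0, 0], [0, 0, 1]]
print(hausdorff_distance(pred, target))  # sqrt(8) ≈ 2.828, from target pixel (2, 2)
print(hausdorff_distance(pred, target, directed=True))  # 0.0, every pred pixel is on the target
```

The brute-force search needs `|P| * |T|` distance evaluations. Practical implementations typically restrict the point sets to boundary pixels or use distance transforms.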

## Usage Examples

```python
import torch
from torchmetrics.segmentation import DiceScore, MeanIoU, HausdorffDistance

# Binary segmentation
dice_binary = DiceScore()
preds_binary = torch.rand(4, 1, 128, 128)  # Probabilities; shape (batch, channels, height, width)
target_binary = torch.randint(0, 2, (4, 1, 128, 128)).float()  # Ground-truth 0/1 mask

dice_result = dice_binary(preds_binary, target_binary)
print(f"Binary Dice Score: {dice_result:.4f}")

# Multi-class segmentation
num_classes = 5
iou_metric = MeanIoU(num_classes=num_classes)
preds_multi = torch.rand(2, num_classes, 64, 64)  # Per-class scores; shape (batch, classes, H, W)
target_multi = torch.randint(0, num_classes, (2, 64, 64))  # Integer class label per pixel

iou_result = iou_metric(preds_multi, target_multi)
print(f"Mean IoU: {iou_result:.4f}")

# Hausdorff distance (requires binary masks)
hausdorff = HausdorffDistance()
binary_preds = (torch.rand(2, 1, 32, 32) > 0.5).float()
binary_targets = torch.randint(0, 2, (2, 1, 32, 32)).float()

hd_result = hausdorff(binary_preds, binary_targets)
print(f"Hausdorff Distance: {hd_result:.4f}")
```
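
In torchmetrics, `Metric` objects are stateful. `update()` accumulates statistics, and `compute()` returns the score over everything seen since the last `reset()`. Calling the metric directly, as above, returns the current batch's value while also accumulating. For a micro-averaged score, the accumulated result is not the same as the average of per-batch results. The hypothetical `RunningDice` class below is a plain-Python sketch of that pattern, not the library's implementation. It assumes flattened 0/1 masks and scores an empty history as 1.0.

```python
from typing import Sequence


class RunningDice:
    """Hypothetical micro-averaged Dice accumulator with an update/compute/reset pattern."""

    def __init__(self) -> None:
        self.intersection = 0
        self.total = 0

    def reset(self) -> None:
        self.intersection = 0
        self.total = 0

    def update(self, pred: Sequence[int], target: Sequence[int]) -> None:
        # Accumulate raw counts, not per-batch scores.
        self.intersection += sum(p & t for p, t in zip(pred, target))
        self.total += sum(pred) + sum(target)

    def compute(self) -> float:
        # Assumption: no foreground seen at all counts as a perfect match.
        return 1.0 if self.total == 0 else 2.0 * self.intersection / self.total


metric = RunningDice()
batches = [([1, 1, 0, 0], [1, 0, 0, 0]), ([1, 1, 1, 1], [1, 1, 1, 1])]
for pred, target in batches:
    metric.update(pred, target)
print(metric.compute())  # 2 * 5 / 11 ≈ 0.909; averaging per-batch scores gives (2/3 + 1) / 2 ≈ 0.833
```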

## Types

```python { .api }
from typing import Literal

from torch import Tensor

SegmentationMask = Tensor  # Binary or multi-class segmentation mask
WeightType = Literal["square", "simple", "uniform"]  # Allowed GeneralizedDiceScore weight_type values
```
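
When the allowed `weight_type` strings are written as a `typing.Literal`, static type checkers can use them, and `typing.get_args` can recover them at runtime. A minimal validation helper, which is hypothetical and not part of the library, might look like this:

```python
from typing import Literal, get_args

WeightType = Literal["square", "simple", "uniform"]


def check_weight_type(value: str) -> str:
    """Hypothetical helper, not a torchmetrics API: return value if allowed, else raise."""
    allowed = get_args(WeightType)  # ('square', 'simple', 'uniform')
    if value not in allowed:
        raise ValueError(f"weight_type must be one of {allowed}, got {value!r}")
    return value


check_weight_type("square")  # passes silently
```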