|
| 1 | +--- |
| 2 | +title: 分割任务中常用的损失函数 |
| 3 | +icon: file |
| 4 | +category: |
| 5 | + - tools |
| 6 | +tag: |
| 7 | + - 已发布 |
| 8 | +footer: 技术共建,知识共享 |
| 9 | +date: 2025-06-11 |
| 10 | +author: |
| 11 | + - BinaryOracle |
| 12 | +--- |
| 13 | + |
| 14 | +`分割任务中常用的损失函数` |
| 15 | + |
| 16 | +<!-- more --> |
| 17 | + |
| 18 | +# 语义分割 |
| 19 | + |
| 20 | +语义分割是计算机视觉领域中的一项任务,旨在将图像中的每个像素分类为不同的语义类别。与对象检测任务不同,语义分割不仅需要识别图像中的物体,还需要对每个像素进行分类,从而实现对图像的细粒度理解和分析。 |
| 21 | + |
| 22 | +语义分割可以被看作是像素级别的图像分割,其目标是为图像中的每个像素分配一个特定的语义类别标签。每个像素都被视为图像的基本单位,因此语义分割可以提供更详细和准确的图像分析结果。 |
| 23 | + |
| 24 | +***语义分割 vs 分类 :*** |
| 25 | + |
| 26 | +1. 在语义分割任务中,由于需要对每个像素进行分类,因此需要使用像素级别的损失函数。 |
| 27 | + |
| 28 | +2. 语义分割任务中,图像中各个类别的像素数量通常不均衡,例如背景像素可能占据了大部分。 |
| 29 | + |
| 30 | +3. 语义分割任务需要对图像中的每个像素进行分类,同时保持空间连续性。 |
| 31 | + |
| 32 | +# 损失函数 |
| 33 | + |
| 34 | +## Dice Loss |
| 35 | + |
| 36 | +Dice Loss 是一种常用于语义分割任务的损失函数,尤其在目标区域较小、类别不平衡(class imbalance)的情况下表现优异。它来源于 Dice 系数(Dice Coefficient) ,又称为 Sørensen-Dice 系数 ,是衡量两个样本集合之间重叠程度的一种指标。 |
| 37 | + |
| 38 | +Dice 系数衡量的是预测掩码与真实标签之间的相似性,公式如下: |
| 39 | + |
| 40 | +$$ |
| 41 | +Dice = \frac{2|X \cap Y|}{|X| + |Y|} |
| 42 | +$$ |
| 43 | + |
| 44 | +其中: |
| 45 | + |
| 46 | +- $X$ :模型预测出的功能区域(如经过 sigmoid 后的概率值); |
| 47 | + |
| 48 | +- $Y$ :Ground Truth 掩码(二值化或软标签); |
| 49 | + |
| 50 | +- $∣X∩Y∣$ :预测为正类且实际也为正类的部分(交集); |
| 51 | + |
| 52 | +- $∣X∣+∣Y∣$ :预测和真实中所有正类区域之和; |
| 53 | + |
| 54 | +> ⚠️ 注意:Dice 系数范围是 [0, 1],越大越好。 |
| 55 | +
|
| 56 | + |
| 57 | +Dice Loss 为了将其作为损失函数使用,我们通常取其补集: |
| 58 | + |
| 59 | +$$ |
| 60 | +Dice = 1−Dice |
| 61 | +$$ |
| 62 | + |
| 63 | +有时也会加入一个平滑项 ϵ 防止除以零: |
| 64 | + |
| 65 | +$$ |
| 66 | +L_{Dice} = 1 - \frac{2\sum(X \cdot Y) + \epsilon}{\sum X + \sum Y + \epsilon} |
| 67 | +$$ |
| 68 | + |
| 69 | +Dice Loss 的优势: |
| 70 | + |
| 71 | +| 优势 | 描述 | |
| 72 | +| --- | --- | |
| 73 | +| 对类别不平衡不敏感,更关注“有没有覆盖正确区域”,而不是“有多少点被正确分类” | 不像 BCE Loss 那样对负样本过多敏感 | |
| 74 | +| 直接优化 IoU 的替代指标 | Dice 和 IoU 表现类似,但更易梯度下降 | |
| 75 | +| 支持 soft mask 输入 | 可处理连续概率值,不需要先 threshold | |
| 76 | +| 更关注整体区域匹配 | 而不是逐点分类 | |
| 77 | + |
| 78 | + |
| 79 | + |
| 80 | +代码实现: |
| 81 | + |
| 82 | +```python |
| 83 | +class DiceLoss(nn.Module): |
| 84 | + def __init__(self, weight=None, size_average=True): |
| 85 | + super(DiceLoss, self).__init__() |
| 86 | + |
| 87 | + def forward(self, inputs, targets, smooth=1): |
| 88 | + |
| 89 | + #comment out if your model contains a sigmoid or equivalent activation layer |
| 90 | + inputs = F.sigmoid(inputs) |
| 91 | + |
| 92 | + #flatten label and prediction tensors |
| 93 | + inputs = inputs.view(-1) |
| 94 | + targets = targets.view(-1) |
| 95 | + |
| 96 | + intersection = (inputs * targets).sum() |
| 97 | + dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) |
| 98 | + |
| 99 | + return 1 - dice |
| 100 | +``` |
0 commit comments