1
1
---
2
- title : 分割任务中常用的损失函数
2
+ title : 语义分割中常用的损失函数
3
3
icon : file
4
4
category :
5
5
- tools
@@ -11,11 +11,11 @@ author:
11
11
- BinaryOracle
12
12
---
13
13
14
- ` 分割任务中常用的损失函数 `
14
+ ` 语义分割中常用的损失函数 `
15
15
16
16
<!-- more -->
17
17
18
- # 语义分割
18
+ ## 语义分割
19
19
20
20
语义分割是计算机视觉领域中的一项任务,旨在将图像中的每个像素分类为不同的语义类别。与对象检测任务不同,语义分割不仅需要识别图像中的物体,还需要对每个像素进行分类,从而实现对图像的细粒度理解和分析。
21
21
@@ -29,9 +29,9 @@ author:
29
29
30
30
3 . 语义分割任务需要对图像中的每个像素进行分类,同时保持空间连续性。
31
31
32
- # 损失函数
32
+ ## 损失函数
33
33
34
- ## Dice Loss
34
+ ### Dice Loss
35
35
36
36
Dice Loss 是一种常用于语义分割任务的损失函数,尤其在目标区域较小、类别不平衡(class imbalance)的情况下表现优异。它来源于 Dice 系数(Dice Coefficient) ,又称为 Sørensen-Dice 系数 ,是衡量两个样本集合之间重叠程度的一种指标。
37
37
@@ -125,7 +125,7 @@ class DiceLoss(nn.Module):
125
125
# 值越小表示匹配越好
126
126
return 1 - dice_score
127
127
```
128
- ## BCE-Dice Loss
128
+ ### BCE-Dice Loss
129
129
130
130
BCE-Dice Loss是将Dice Loss和标准的二元交叉熵(Binary Cross-Entropy, BCE)损失结合在一起的一种损失函数,通常用于分割模型中。它结合了两种 loss 的优点:
131
131
@@ -252,22 +252,294 @@ class DiceBCELoss(nn.Module):
252
252
return Dice_BCE
253
253
```
254
254
255
- ## Jaccard/Intersection over Union (IoU) Loss
255
+ ### Jaccard/Intersection over Union (IoU) Loss
256
256
257
+ Jaccard Loss,也称为Intersection over Union (IoU) Loss,是一种常用的损失函数,用于语义分割任务中评估模型的分割结果与真实分割标签之间的相似性。它基于Jaccard指数(Jaccard Index),也称为 交并比(Intersection over Union, IoU)指标,用于度量两个集合之间的重叠程度。
257
258
259
+ 1 . ** Jaccard Index(IoU)**
258
260
259
- ## Focal Loss
261
+ $$
262
+ \text{IoU} = \frac{|X \cap Y|}{|X \cup Y|}
263
+ = \frac{\sum (\hat{y}_i \cdot y_i)}{\sum \hat{y}_i + \sum y_i - \sum (\hat{y}_i \cdot y_i)}
264
+ $$
265
+
266
+ 其中:
267
+
268
+ - $\hat{y}_ i$:模型输出的概率值或二值化结果;
269
+ - $y_i$:ground truth 掩码;
270
+ - 分子是预测和 GT 的交集;
271
+ - 分母是两者的并集;
272
+
273
+ > ⚠️ IoU 值 ∈ [ 0, 1] ,越大越好。
274
+
275
+ ---
276
+
277
+ 2 . ** Jaccard Loss(IoU Loss)**
278
+
279
+ 为了将 IoU 转换为可优化的损失函数,我们取其补集:
280
+
281
+ $$
282
+ \mathcal{L}_{\text{IoU}} = 1 - \text{IoU}
283
+ $$
284
+
285
+ 这样,损失越小表示预测越接近真实标签。
286
+
287
+ 为了避免除以零,通常加入平滑项 $\epsilon$:
288
+
289
+ $$
290
+ \mathcal{L}_{\text{IoU}} = 1 - \frac{\sum (\hat{y}_i \cdot y_i) + \epsilon}{\sum \hat{y}_i + \sum y_i - \sum (\hat{y}_i \cdot y_i) + \epsilon}
291
+ $$
292
+
293
+ ---
294
+
295
+ 3 . Jaccard Loss 有以下几个优点:
296
+
297
+ | 特性 | 描述 |
298
+ | ------| ------|
299
+ | ✔️ 对类别不平衡不敏感 | 不像 BCE Loss 那样偏向背景点 |
300
+ | ✔️ 关注整体区域匹配 | 强调预测与 GT 的空间一致性 |
301
+ | ✔️ 更适合评估边界模糊区域 | 如功能区域边缘不确定性较高 |
302
+
303
+ ---
304
+
305
+ 4 . 与其他 Loss 的对比
306
+
307
+ | 损失函数 | 是否支持 soft mask | 是否对类别不平衡敏感 | 是否直接优化 IoU | 输出范围 |
308
+ | ----------| -------------------| ------------------------| ---------------------| ------------|
309
+ | ** BCE Loss** | ❌ 否(需二值化) | ✅ 是 | ❌ 否 | [ 0, ∞) |
310
+ | ** Focal Loss** | ✅ 是(加权) | ✅ 是(缓解) | ❌ 否 | [ 0, ∞) |
311
+ | ** Dice Loss** | ✅ 是 | ✅ 是 | 近似于 IoU | [ 0, 1] |
312
+ | ** Jaccard (IoU) Loss** | ✅ 是 | ✅ 是 | ✅ 是 | [ 0, 1] |
313
+
314
+ 虽然 Dice Loss 在实际训练中更稳定,但 Jaccard Loss 更贴近最终评估指标(IoU),适合在推理阶段作为验证标准。
315
+
316
+ ---
317
+
318
+ 代码实现:
319
+
320
+ ``` python
321
+ class IoULoss (nn .Module ):
322
+ def __init__ (self , weight = None , size_average = True ):
323
+ """
324
+ 初始化函数,构建一个基于 IoU(交并比)的损失函数。
325
+
326
+ 参数:
327
+ weight (Tensor): 可选参数,用于类别加权(未使用)
328
+ size_average (bool): 是否对 batch 内样本取平均 loss(已弃用)
329
+ """
330
+ super (IoULoss, self ).__init__ ()
331
+ # weight 和 size_average 在此实现中未使用,保留接口以备后续扩展
332
+
333
+ def forward (self , inputs , targets , smooth = 1 ):
334
+ """
335
+ 前向传播函数,计算预测输出与真实标签之间的 IoU Loss。
336
+
337
+ 参数:
338
+ inputs (Tensor): 模型输出的原始 logit 或经过 sigmoid 的概率值;
339
+ 形状为 [B, N]
340
+ targets (Tensor): ground truth 掩码,形状为 [B, N]
341
+ smooth (float): 平滑项,防止除零错误,默认为 1
342
+
343
+ 返回:
344
+ iou_loss (Tensor): 计算得到的 IoU Loss
345
+ """
346
+
347
+ # 如果模型最后没有 sigmoid 层,则在这里激活
348
+ # 如果已经包含 sigmoid,则应注释掉这一行
349
+ inputs = torch.sigmoid(inputs) # 将输入映射到 [0,1] 区间
350
+
351
+ # 将输入和目标展平成一维张量便于计算
352
+ # inputs: [B*N]
353
+ # targets: [B*N]
354
+ inputs = inputs.view(- 1 )
355
+ targets = targets.view(- 1 )
356
+
357
+ # 计算交集(Intersection),等价于 TP(True Positive)
358
+ intersection = (inputs * targets).sum()
359
+
360
+ # 计算并集:Union = input + target - intersection
361
+ total = (inputs + targets).sum()
362
+ union = total - intersection
363
+
364
+ # 计算 IoU Score,加入平滑项防止除以零
365
+ iou_score = (intersection + smooth) / (union + smooth)
366
+
367
+ # IoU Loss = 1 - IoU score,这样越接近 1,loss 越小
368
+ iou_loss = 1 . - iou_score
369
+
370
+ return iou_loss
371
+ ```
372
+
373
+ ### Focal Loss
374
+
375
+ Focal Loss 是一种针对类别不平衡(Class Imbalance)问题的损失函数改进方案,由何恺明团队在2017年论文《Focal Loss for Dense Object Detection》中提出,主要用于解决目标检测任务中前景-背景类别极端不平衡的问题(如1:1000)。其核心思想是** 通过调整难易样本的权重,使模型更关注难分类的样本** 。
376
+
377
+ Focal Loss 基于交叉熵损失进行扩展,将样本的权重进行动态调整。与交叉熵损失函数相比,Focal Loss引入了一个衰减因子$(1 - pt)^\gamma$,其中 pt 是预测的概率值。这个衰减因子能够使得易分类的样本( pt较高 )的权重降低,从而减少对分类正确样本的贡献。
378
+
379
+ ** 核心思想:**
380
+
381
+ ** (1) 类别不平衡的问题**
382
+
383
+ 在分类任务中(尤其是目标检测),负样本(背景)往往远多于正样本(目标),导致:
384
+
385
+ - 模型被大量简单负样本主导,难以学习有效特征。
386
+
387
+ - 简单样本的梯度贡献淹没难样本的梯度。
388
+
389
+ ** (2) Focal Loss 的改进**
390
+
391
+ - ** 降低易分类样本的权重** :对模型已经分类正确的样本(高置信度)减少损失贡献。
392
+
393
+ - ** 聚焦难分类样本** :对分类错误的样本(低置信度)保持高损失权重。
394
+
395
+ ---
396
+
397
+ Focal Loss 基于标准交叉熵损失(Cross-Entropy Loss)改进而来。
398
+
399
+ ** (1) 标准交叉熵损失(CE Loss)**
400
+
401
+ $$
402
+ CE(p, y) =
403
+ \begin{cases}
404
+ -\log(p) & \text{if } y=1 \\
405
+ -\log(1-p) & \text{if } y=0
406
+ \end{cases}
407
+ $$
408
+
409
+ 其中:
410
+ - p 是模型预测的概率(经过sigmoid/softmax)。
411
+ - y 是真实标签(0或1)。
412
+
413
+ ** (2) Focal Loss 定义**
414
+
415
+ $$
416
+ FL(p, y) =
417
+ \begin{cases}
418
+ -\alpha (1-p)^\gamma \log(p) & \text{if } y=1 \\
419
+ -(1-\alpha) p^\gamma \log(1-p) & \text{if } y=0
420
+ \end{cases}
421
+ $$
422
+ - ** $\alpha$** :类别平衡权重(通常$\alpha \in [ 0,1] $),用于平衡正负样本数量差异。
423
+ - ** $\gamma$** :调节因子(通常$\gamma \geq 0$),控制难易样本的权重衰减程度。
424
+
425
+ > γ 参数用于抑制容易分类的样本,而 α 参数用于平衡正负类别的权重。两者解决的是不同维度的问题:
426
+ >
427
+ > - α:防止前景点(功能区域)被背景淹没,解决数据集中“类别数量不平衡”的问题(数据集级别);
428
+ >
429
+ > - γ:防止模型只关注简单样本,忽略难分类样本,解决模型训练时“简单样本主导梯度”的问题(样本级别);
430
+ >
431
+ > 综上,先通过 α 平衡类别数量,再通过 γ 抑制简单样本,两者协同提升模型性能。
432
+
433
+ ---
434
+
435
+ ** 关键参数的作用:**
436
+
437
+ | 参数 | 作用 | 典型值 |
438
+ | -----------| ----------------------------------------------------------------------| --------------|
439
+ | ** $\gamma$** | 控制难易样本权重:<br >• $\gamma=0$:退化为CE Loss<br >• $\gamma=2$:显著抑制简单样本 | 0.5 ~ 5 |
440
+ | ** $\alpha$** | 平衡正负样本数量:<br >• $\alpha=0.75$:正样本较少时增加权重 | 0.25 ~ 0.75 |
260
441
442
+ ** 难样本vs易样:**
443
+
444
+ - ** 易分类样本** (如 p=0.9 ): $(1-p)^\gamma$ 接近0,损失被大幅降低。
445
+
446
+ - ** 难分类样本** (如 p=0.1 ): $(1-p)^\gamma$ 接近1,损失几乎不受影响。
447
+
448
+ > 假设两个正样本:
449
+ >
450
+ > 1 . ** 易样本** :$p=0.9$(模型已自信分类)
451
+ >
452
+ > - 标准 CE Loss:$-\log(0.9) \approx 0.105$
453
+ >
454
+ > - Focal Loss($\gamma=2$):$(1-0.9)^2 \times 0.105 \approx 0.001$ ** 损失权重降低 100 倍** !
455
+ >
456
+ > 2 . ** 难样本** :$p=0.1$(模型分类错误)
457
+ >
458
+ > - 标准 CE Loss:$-\log(0.1) \approx 2.302$
459
+ >
460
+ > - Focal Loss($\gamma=2$):$(1-0.1)^2 \times 2.302 \approx 1.866$ ** 损失权重仅降低 20%** 。
461
+
462
+ ** 应用场景:**
463
+
464
+ 1 . ** 目标检测** (如RetinaNet): 解决前景(目标)与背景的极端不平衡问题。
465
+
466
+ 2 . ** 医学图像分割** : 病灶区域像素远少于正常组织。
467
+
468
+ 3 . ** 任何类别不平衡的分类任务** : 如欺诈检测、罕见疾病诊断等。
469
+
470
+ ** 优缺点:**
471
+
472
+ | ** 优点** | ** 缺点** |
473
+ | ------------------------------| ------------------------------|
474
+ | 显著提升难样本的分类性能 | 需调参($\alpha, \gamma$)|
475
+ | 抑制简单样本的梯度主导 | 对噪声标签敏感 |
476
+ | 兼容大多数分类模型 | 计算量略高于CE Loss |
477
+
478
+
479
+ - ** Focal Loss 通过 $(1-p)^\gamma$ 动态调整样本权重** ,使模型聚焦难分类样本。
480
+
481
+ - ** 参数选择** :
482
+
483
+ - $\gamma$:一般从2开始调优(值越大,简单样本抑制越强)。
484
+
485
+ - $\alpha$:根据正负样本比例调整(如正样本少则增大 $\alpha$)。
486
+
487
+ - ** 适用场景** :类别不平衡越严重,Focal Loss 效果越显著。
488
+
489
+ ---
490
+
491
+ 代码实现:
492
+
493
+ ``` python
494
+ class FocalLoss (nn .Module ):
495
+ def __init__ (self , alpha = 0.8 , gamma = 2 , reduction = ' mean' ):
496
+ super (FocalLoss, self ).__init__ ()
497
+ self .alpha = alpha
498
+ self .gamma = gamma
499
+ self .reduction = reduction
500
+
501
+ def forward (self , inputs , targets ):
502
+ """
503
+ inputs: raw logits [B, N]
504
+ targets: ground truth mask [B, N] (soft or binary)
505
+ """
506
+
507
+ # Step 1: Sigmoid 映射
508
+ probs = torch.sigmoid(inputs)
509
+
510
+ # Step 2: 根据 target 构建 pt
511
+ pt = probs * targets + (1 - probs) * (1 - targets)
512
+
513
+ # Step 3: 构建 at(正负样本权重)
514
+ at = self .alpha * targets + (1 - self .alpha) * (1 - targets)
515
+
516
+ # Step 4: 计算 Focal Weight
517
+ focal_weight = torch.pow(1 - pt, self .gamma)
518
+
519
+ # Step 5: BCE Loss(binary_cross_entropy_with_logits 已包含 sigmoid)
520
+ BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction = ' none' )
521
+
522
+ # Step 6: 应用 Focal Weight
523
+ focal_loss = at * focal_weight * BCE
524
+
525
+ # Step 7: 返回结果
526
+ if self .reduction == ' mean' :
527
+ return focal_loss.mean()
528
+ elif self .reduction == ' sum' :
529
+ return focal_loss.sum()
530
+ else :
531
+ return focal_loss
532
+ ```
261
533
262
- ## Tversky Loss
534
+ ### Tversky Loss
263
535
264
536
265
537
266
- ## Lovasz Hinge Loss
538
+ ### Lovasz Hinge Loss
267
539
268
540
269
541
270
- ## Combo Loss
542
+ ### Combo Loss
271
543
272
544
273
545
0 commit comments