@@ -677,12 +677,100 @@ _3daffordance = torch.sigmoid(_3daffordance)
677
677
678
678
## 损失函数
679
679
680
- ### Focal Loss
680
+ ### HM_Loss(Hybrid Mask Loss)
681
681
682
+ 在 LASO 数据集中,模型需要根据自然语言问题识别点云中最相关的功能区域(如 grasping area, opening area 等),而 HM_Loss 是 PointRefer 模型的监督信号,它结合了:
682
683
684
+ - Focal Loss :用于缓解类别不平衡问题;
683
685
684
- ### Dice Loss
686
+ - Dice Loss :用于衡量预测掩码与真实标签之间的空间重合度;
685
687
688
+ 最终 loss = CELoss + DiceLoss,让模型同时关注逐点分类精度和整体区域匹配。
689
+
690
+ ``` python
691
+ import torch
692
+ import torch.nn as nn
693
+ import torch.nn.functional as F
694
+
695
+
696
+ class HM_Loss (nn .Module ):
697
+ def __init__ (self ):
698
+ """
699
+ Hybrid Mask Loss 实现:
700
+ - BCE-Focal Loss(加权交叉熵)
701
+ - Dice Loss(衡量预测掩码与 GT 的重合度)
702
+
703
+ 公式来自论文 Section 4.2,用于语言引导下的功能区域分割。
704
+ """
705
+ super (HM_Loss , self ).__init__ ()
706
+ # 设置 Focal Loss 参数
707
+ self .gamma = 2 # 聚焦参数,放大难分类样本影响
708
+ self .alpha = 0.25 # 平衡因子,强调正类(前景点)loss
709
+
710
+ def forward (self , pred , target ):
711
+ """
712
+ 输入:
713
+ pred: 模型输出的原始 logit 或经过 sigmoid 的概率值;
714
+ 形状为 [B, N]
715
+ target: ground truth 掩码(soft mask),形状也为 [B, N]
716
+
717
+ 返回:
718
+ total_loss: CELoss + DiceLoss 的加权和
719
+ """
720
+
721
+ # Step 1: 构建 Focal Loss 权重项
722
+ # temp1:负类 loss(背景点)
723
+ # temp2:正类 loss(目标功能区域)
724
+ # 1e-6 的加入是为了让 log 计算保持稳定,尤其是在预测值接近极端值(0 或 1)时
725
+ temp1 = - (1 - self .alpha) * torch.mul(
726
+ pred ** self .gamma,
727
+ torch.mul(1 - target, torch.log(1 - pred + 1e-6 ))
728
+ )
729
+ temp2 = - self .alpha * torch.mul(
730
+ (1 - pred) ** self .gamma,
731
+ torch.mul(target, torch.log(pred + 1e-6 ))
732
+ )
733
+
734
+ # 将两个方向的 loss 合并,并取 batch 和点维度的平均
735
+ temp = temp1 + temp2
736
+ CELoss = torch.sum(torch.mean(temp, dim = (0 , 1 )))
737
+
738
+ # Step 2: 计算正类 Dice Loss(预测与 Ground Truth 的交集 / 并集)
739
+ intersection_positive = torch.sum(pred * target, dim = 1 )
740
+ cardinality_positive = torch.sum(torch.abs(pred) + torch.abs(target), dim = 1 )
741
+ dice_positive = (intersection_positive + 1e-6 ) / (cardinality_positive + 1e-6 )
742
+
743
+ # Step 3: 计算负类 Dice Loss(非目标区域匹配度)
744
+ intersection_negative = torch.sum((1 - pred) * (1 - target), dim = 1 )
745
+ cardinality_negative = torch.sum(2 - torch.abs(pred) - torch.abs(target), dim = 1 )
746
+ dice_negative = (intersection_negative + 1e-6 ) / (cardinality_negative + 1e-6 )
747
+
748
+ # Step 4: 构建 Dice Loss,形式为 1 - Dice Score
749
+ # 使用了一个偏置项 1.5(可能是经验设定)
750
+ temp3 = torch.mean(1.5 - dice_positive - dice_negative, dim = 0 )
751
+ DICELoss = torch.sum(temp3)
752
+
753
+ # Step 5: 总损失 = 分类误差 + 区域匹配误差
754
+ return CELoss + 1.0 * DICELoss
755
+ ```
756
+
757
+ 在论文 Section 4.2 中提到:
758
+
759
+ > “We solely employ Dice loss and Binary Cross-Entropy (BCE) loss to guide the segmentation mask prediction.”
760
+
761
+ 虽然这里用的是 Focal Loss + Dice Loss 的组合形式,但它本质上是 BCE + Dice 的改进版,具有以下优势:
762
+
763
+ - Focal Loss: 抑制 easy examples,放大 hard examples,防止忽略小区域
764
+
765
+ - Dice Loss: 关注整体掩码匹配度,提升边界识别能力
766
+
767
+ 两者结合可以:
768
+
769
+ - 缓解类别极度不平衡问题;
770
+
771
+ - 提高模型对语言指令下功能区域的理解能力;
772
+
773
+ - 更好地应对 LASO 中的语言引导 + soft mask 场景;
686
774
687
775
## 训练
688
776
0 commit comments