Skip to content

Commit 2834c12

Browse files
committed
updates
1 parent e113919 commit 2834c12

File tree

1 file changed

+90
-2
lines changed

1 file changed

+90
-2
lines changed

src/3DVL/LASO.md

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,12 +677,100 @@ _3daffordance = torch.sigmoid(_3daffordance)
677677

678678
## 损失函数
679679

680-
### Focal Loss
680+
### HM_Loss(Hybrid Mask Loss
681681

682+
在 LASO 数据集中,模型需要根据自然语言问题识别点云中最相关的功能区域(如 grasping area, opening area 等),而 HM_Loss 是 PointRefer 模型的监督信号,它结合了:
682683

684+
- Focal Loss :用于缓解类别不平衡问题;
683685

684-
### Dice Loss
686+
- Dice Loss :用于衡量预测掩码与真实标签之间的空间重合度;
685687

688+
最终 loss = CELoss + DiceLoss,让模型同时关注逐点分类精度和整体区域匹配。
689+
690+
```python
691+
import torch
692+
import torch.nn as nn
693+
import torch.nn.functional as F
694+
695+
696+
class HM_Loss(nn.Module):
697+
def __init__(self):
698+
"""
699+
Hybrid Mask Loss 实现:
700+
- BCE-Focal Loss(加权交叉熵)
701+
- Dice Loss(衡量预测掩码与 GT 的重合度)
702+
703+
公式来自论文 Section 4.2,用于语言引导下的功能区域分割。
704+
"""
705+
super(HM_Loss, self).__init__()
706+
# 设置 Focal Loss 参数
707+
self.gamma = 2 # 聚焦参数,放大难分类样本影响
708+
self.alpha = 0.25 # 平衡因子,强调正类(前景点)loss
709+
710+
def forward(self, pred, target):
711+
"""
712+
输入:
713+
pred: 模型输出的原始 logit 或经过 sigmoid 的概率值;
714+
形状为 [B, N]
715+
target: ground truth 掩码(soft mask),形状也为 [B, N]
716+
717+
返回:
718+
total_loss: CELoss + DiceLoss 的加权和
719+
"""
720+
721+
# Step 1: 构建 Focal Loss 权重项
722+
# temp1:负类 loss(背景点)
723+
# temp2:正类 loss(目标功能区域)
724+
# 1e-6 的加入是为了让 log 计算保持稳定,尤其是在预测值接近极端值(0 或 1)时
725+
temp1 = -(1 - self.alpha) * torch.mul(
726+
pred ** self.gamma,
727+
torch.mul(1 - target, torch.log(1 - pred + 1e-6))
728+
)
729+
temp2 = -self.alpha * torch.mul(
730+
(1 - pred) ** self.gamma,
731+
torch.mul(target, torch.log(pred + 1e-6))
732+
)
733+
734+
# 将两个方向的 loss 合并,并取 batch 和点维度的平均
735+
temp = temp1 + temp2
736+
CELoss = torch.sum(torch.mean(temp, dim=(0, 1)))
737+
738+
# Step 2: 计算正类 Dice Loss(预测与 Ground Truth 的交集 / 并集)
739+
intersection_positive = torch.sum(pred * target, dim=1)
740+
cardinality_positive = torch.sum(torch.abs(pred) + torch.abs(target), dim=1)
741+
dice_positive = (intersection_positive + 1e-6) / (cardinality_positive + 1e-6)
742+
743+
# Step 3: 计算负类 Dice Loss(非目标区域匹配度)
744+
intersection_negative = torch.sum((1 - pred) * (1 - target), dim=1)
745+
cardinality_negative = torch.sum(2 - torch.abs(pred) - torch.abs(target), dim=1)
746+
dice_negative = (intersection_negative + 1e-6) / (cardinality_negative + 1e-6)
747+
748+
# Step 4: 构建 Dice Loss,形式为 1 - Dice Score
749+
# 使用了一个偏置项 1.5(可能是经验设定)
750+
temp3 = torch.mean(1.5 - dice_positive - dice_negative, dim=0)
751+
DICELoss = torch.sum(temp3)
752+
753+
# Step 5: 总损失 = 分类误差 + 区域匹配误差
754+
return CELoss + 1.0 * DICELoss
755+
```
756+
757+
在论文 Section 4.2 中提到:
758+
759+
> “We solely employ Dice loss and Binary Cross-Entropy (BCE) loss to guide the segmentation mask prediction.”
760+
761+
虽然这里用的是 Focal Loss + Dice Loss 的组合形式,但它本质上是 BCE + Dice 的改进版,具有以下优势:
762+
763+
- Focal Loss: 抑制 easy examples,放大 hard examples,防止忽略小区域
764+
765+
- Dice Loss: 关注整体掩码匹配度,提升边界识别能力
766+
767+
两者结合可以:
768+
769+
- 缓解类别极度不平衡问题;
770+
771+
- 提高模型对语言指令下功能区域的理解能力;
772+
773+
- 更好地应对 LASO 中的语言引导 + soft mask 场景;
686774

687775
## 训练
688776

0 commit comments

Comments
 (0)