损失函数 的 硬截断 和 平滑衰减flyfish在逐样本损失计算完成、取平均之前对损失过高的样本做权重压制不删除样本只削弱它们对梯度的贡献属于软降权——既保留了样本的监督信号又避免极端难样本/疑似错标样本带偏整个模型。损失硬截断损失硬截断是给单样本损失设置一个上限超过这个阈值的损失直接按阈值计算。相当于一刀切超过上限的样本梯度不再放大。代码实现classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma2,alphaNone,smoothing0.0,num_classes2,max_lossNone): :param max_loss: 单样本损失上限None表示不开启截断设置数值后单样本损失不会超过该值 super().__init__()self.gammagamma self.alphatorch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothingsmoothing self.num_classesnum_classes self.max_lossmax_loss# 损失截断阈值defforward(self,inputs,targets):targets_one_hottorch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targetstargets_one_hot*(1-self.smoothing)self.smoothing/self.num_classes log_probstorch.nn.functional.log_softmax(inputs,dim1)probstorch.exp(log_probs)p_t(probs*targets_one_hot).sum(dim1,keepdimTrue)focal_weight(1-p_t)**self.gamma ce_loss(-soft_targets*log_probs).sum(dim1)lossfocal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim1)lossloss*alpha_t# 损失截断 ifself.max_lossisnotNone:losstorch.clamp(loss,maxself.max_loss)returnloss.mean()使用方式在训练函数里初始化损失时多加一个max_loss参数即可# 示例单样本损失最高不超过2.0超过的全部按2.0计算criterionFocalLossWithSmoothing(gammaFOCAL_GAMMA,alphaFOCAL_ALPHA,smoothingLABEL_SMOOTHING,num_classesNUM_CLASSES,max_loss2.0# 开启截断阈值可按需调整)平滑衰减降权硬截断是一刀切损失超过阈值直接砍平损失值瞬间不再增长像台阶一样突变平滑衰减是越涨越慢损失低于阈值时正常计算超过阈值后还能继续涨但增长速度会越来越慢过渡是顺滑的曲线没有突变台阶。它的目的既保留损失越高、权重越大的相对顺序又不让极端高损失样本无限放大梯度带偏模型同时保证训练过程梯度平稳不会出现跳变。代码实现 只需要把截断部分替换成平滑衰减逻辑即可classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma2,alphaNone,smoothing0.0,num_classes3,loss_threshold1.8):super().__init__()self.gammagamma self.alphatorch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothingsmoothing self.num_classesnum_classes self.loss_thresholdloss_threshold# 平滑衰减阈值defforward(self,inputs,targets):targets_one_hottorch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targetstargets_one_hot*(1-self.smoothing)self.smoothing/self.num_classes log_probstorch.nn.functional.log_softmax(inputs,dim1)probstorch.exp(log_probs)p_t(probs*targets_one_hot).sum(dim1,keepdimTrue)focal_weight(1-p_t)**self.gamma ce_loss(-soft_targets*log_probs).sum(dim1)lossfocal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim1)lossloss*alpha_t# 平滑衰减降权压制极端高损失样本ifself.loss_thresholdisnotNone:high_loss_masklossself.loss_threshold loss[high_loss_mask]self.loss_thresholdtorch.log(1loss[high_loss_mask]-self.loss_threshold)returnloss.mean()假设设置阈值 1.5看不同原始损失对应的处理结果原始单样本损失硬截断后损失变化特点1.0正常样本1.0低于阈值完全不变1.4较难样本1.4低于阈值完全不变1.5阈值点1.5刚好等于阈值1.6难样本1.5超过一点点直接被砍成1.5瞬间停止增长3.0极难/错标样本1.5不管多高全砍成1.5和1.6的样本权重完全一样硬截断的问题阈值点处损失突变梯度也会突变训练过程容易出现震荡所有超过阈值的样本损失都一样丢失了难分程度的差异信息——3.0的极难样本和1.6的轻微难样本对模型的贡献变得完全相同有点矫枉过正。平滑衰减的逻辑两段式 对数压缩代码里用的是阈值以下正常计算阈值以上对数压缩的两段式策略公式是处理后损失{原始损失原始损失≤阈值阈值log⁡(1原始损失−阈值)原始损失阈值 \text{处理后损失} \begin{cases} \text{原始损失} \text{原始损失} \le 阈值 \\ 阈值 \log(1 \text{原始损失} - 阈值) \text{原始损失} 阈值 \end{cases}处理后损失{原始损失阈值log(1原始损失−阈值)​原始损失≤阈值原始损失阈值​为什么用 log对数函数对数函数有两个完美匹配需求的特性单调递增原始损失越大处理后的损失也一定越大不会改变谁更难、谁损失更高的排序样本的相对权重关系保留了增速递减x 越大log(x) 涨得越慢。原始损失越高压缩力度越强正好符合极端样本降权更多的需求。直观对比效果还是设阈值 1.5算一组真实数值一眼就能看出区别原始单样本损失硬截断后平滑衰减后直观感受1.01.01.00低于阈值两者完全一样1.41.41.40低于阈值两者完全一样1.51.51.50阈值点两者对齐1.61.51.595只超了一点点压缩很轻微几乎和原值差不多2.01.51.693超了0.5增长明显放缓不再是直线涨3.01.51.946超了1.5涨幅被大幅压缩不会涨到3.05.01.52.208超了3.5增速进一步变慢和3.0的差距被缩小可以明显看到刚超过阈值时损失几乎不受影响过渡非常顺滑损失越高被压缩得越厉害但始终保持越高越重的排序不会像硬截断那样所有高损失全变成同一个值。对应代码loss[high_loss_mask]self.loss_thresholdtorch.log(1loss[high_loss_mask]-self.loss_threshold)拆解开loss[high_loss_mask] - self.loss_threshold算出损失超出阈值的部分增量1 增量加1保证对数的输入大于0避免出现负数报错torch.log(...)对超出的增量做对数压缩让增量涨得变慢self.loss_threshold 压缩后的增量把基准阈值加回来保证阈值点处数值连续、没有台阶。什么时候用硬截断什么时候用平滑衰减方案场景特点硬截断确定有大量标注错误想直接屏蔽极端错标的影响简单粗暴可控性强调试方便平滑衰减样本大多是标注正确的难样本比如小目标、低对比度只想削弱、不想完全屏蔽更温和梯度平稳训练更稳定保留难样本的相对差异信息