今天看代码,学习了一种图像分割领域使用的损失函数(目标检测领域也会用,但这里只介绍图像分割):OhemCELoss,这里记录一下。
函数名 Online hard example mining cross-entropy loss 的缩写。其中 Online hard example mining 的意思是,在训练过程中关注 hard example ,对其施加更高权重的一种训练策略。cross-entropy loss 就是普通的交叉熵损失函数。
cross-entropy loss
首先回顾一下多分类问题的 cross-entropy loss 的公式:
其中 为预测值,
为真实类别,大括号里面计算了样本在真实类别上的概率,这个概率越大(越接近 1),其负对数就越接近 0;反之,这个概率越小,其负对数就越接近正无穷。
pytorch 下的接口为:
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
-
weight
(Tensor, optional) 表示每个类别设置的权重,用在类别不平衡的数据集上 -
size_average
(bool, optional) ,True 计算整个 batch 上的平均值作为 loss,False 计算batch loss 之和。 -
ignore_index
(int, optional) 某个 ground truth 值被忽略,即当 ground truth 为它时,损失直接设为 0。 -
reduce
(bool, optional) 默认情况下,计算的 loss 为一个 batch 中每个元素的 loss 之和或者平均,设置为 False 后,返回每个元素的 loss。 -
reduction
(string, optional) 设置不同的输出格式,可以设为'none' | 'mean' | 'sum'
,由于它的存在,size_average
和reduce
将被淘汰(size_average=True
相当于 'mean',size_average=False
相当于 'sum',reduce=False
相当于 'none')。
代码示例如下:
loss = nn.CrossEntropyLoss()
x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
out = loss(x, label)
为了深入理解原理,也可以自己实现:
x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
exp_x = torch.exp(x)
prob_x = exp_x/exp_x.sum(dim=1, keepdim=True)
prob_x = torch.gather(prob_x, dim=1, index=label.unsqueeze(1))
# reduction = 'none'
print(-torch.log(prob_x))
# reduction = 'sum'
print(-torch.log(prob_x).sum())
# reduction = 'mean'
print(-torch.log(prob_x).mean())
OhemCELoss
先贴代码
class OhemCELoss(nn.Module):
def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
super(OhemCELoss, self).__init__()
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
self.n_min = n_min
self.ignore_lb = ignore_lb
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
def forward(self, logits, labels):
N, C, H, W = logits.size()
loss = self.criteria(logits, labels).view(-1)
loss, _ = torch.sort(loss, descending=True)
if loss[self.n_min] > self.thresh:
loss = loss[loss>self.thresh]
else:
loss = loss[:self.n_min]
return torch.mean(loss)
该损失函数有几个超参数:thresh
,n_min
,其中 thresh
表示的是,损失函数大于多少的时候,会被用来做反向传播。n_min
表示的是,在一个 batch 中,最少需要考虑多少个样本。
需要注意一点,参数 thresh
是概率,即 小于这个概率的预测值会参与计算损失。
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
将概率转化为其对应的 loss 。
代码中的 logits
维度为 ,
labels
维度为 。
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
设置 reduction 为 none,保留每个元素的损失,返回的维度为 。
loss = self.criteria(logits, labels).view(-1)
将预测的损失拉平为一个长向量,每个元素为一个 pixel 的损失。
loss, _ = torch.sort(loss, descending=True)
将长向量中每个 pixel 的损失按从大到小排序。
if loss[self.n_min] > self.thresh:
最少考虑 n_min
个损失最大的 pixel,如果前 n_min
个损失中最小的那个的损失仍然大于设定的阈值,那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]
。
否则,计算前 n_min
个损失:loss = loss[:self.n_min]
。
最后,求这些 hard example 的损失的均值作为最终损失:torch.mean(loss)
。
总结
在图像分割问题中,以每个 pixel 的损失为最小单元,而不是 batch 中每张图片。因此排序时需要把 batch 中所有 pixel 拉成一个长向量,再取其中大于阈值的 pixel 作为 hard example。同时,n_min
的设置保证了每个 batch 中都有至少 n_min
个pixel 参与训练,从而一定程度巩固了训练结果,让前向传播不至于空耗。