图像分割 OhemCELoss

今天看代码,学习了一种图像分割领域使用的损失函数(目标检测领域也会用,但这里只介绍图像分割):OhemCELoss,这里记录一下。

函数名 Online hard example mining cross-entropy loss 的缩写。其中 Online hard example mining 的意思是,在训练过程中关注 hard example ,对其施加更高权重的一种训练策略。cross-entropy loss 就是普通的交叉熵损失函数。

cross-entropy loss

首先回顾一下多分类问题的 cross-entropy loss 的公式:
\mathrm{loss}(x, cls) = -\mathrm{log}\left(\frac{\mathrm{exp}(x[cls])}{\sum_j\mathrm{exp}(x[j])}\right)
其中 x 为预测值,cls 为真实类别,大括号里面计算了样本在真实类别上的概率,这个概率越大(越接近 1),其负对数就越接近 0;反之,这个概率越小,其负对数就越接近正无穷。

pytorch 下的接口为:
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

  • weight (Tensor, optional) 表示每个类别设置的权重,用在类别不平衡的数据集上
  • size_average (bool, optional) ,True 计算整个 batch 上的平均值作为 loss,False 计算batch loss 之和。
  • ignore_index (int, optional) 某个 ground truth 值被忽略,即当 ground truth 为它时,损失直接设为 0。
  • reduce (bool, optional) 默认情况下,计算的 loss 为一个 batch 中每个元素的 loss 之和或者平均,设置为 False 后,返回每个元素的 loss。
  • reduction (string, optional) 设置不同的输出格式,可以设为 'none' | 'mean' | 'sum',由于它的存在,size_averagereduce 将被淘汰(size_average=True 相当于 'mean',size_average=False 相当于 'sum',reduce=False 相当于 'none')。

代码示例如下:

loss = nn.CrossEntropyLoss()
x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
out = loss(x, label)

为了深入理解原理,也可以自己实现:

x = torch.randn(3, 5)
label = torch.empty(3, dtype=torch.long).random_(5)
exp_x = torch.exp(x)
prob_x = exp_x/exp_x.sum(dim=1, keepdim=True)
prob_x = torch.gather(prob_x, dim=1, index=label.unsqueeze(1))
# reduction = 'none'
print(-torch.log(prob_x))
# reduction = 'sum'
print(-torch.log(prob_x).sum())
# reduction = 'mean'
print(-torch.log(prob_x).mean())

OhemCELoss

先贴代码

class OhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

该损失函数有几个超参数:threshn_min,其中 thresh 表示的是,损失函数大于多少的时候,会被用来做反向传播。n_min 表示的是,在一个 batch 中,最少需要考虑多少个样本。
需要注意一点,参数 thresh 是概率,即 小于这个概率的预测值会参与计算损失。

self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)) 将概率转化为其对应的 loss 。

代码中的 logits 维度为 N\times C\times H\times Wlabels 维度为 N\times H \times W

self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 设置 reduction 为 none,保留每个元素的损失,返回的维度为 N\times H\times W

loss = self.criteria(logits, labels).view(-1) 将预测的损失拉平为一个长向量,每个元素为一个 pixel 的损失。

loss, _ = torch.sort(loss, descending=True) 将长向量中每个 pixel 的损失按从大到小排序。

if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]

否则,计算前 n_min 个损失:loss = loss[:self.n_min]

最后,求这些 hard example 的损失的均值作为最终损失:torch.mean(loss)

总结

在图像分割问题中,以每个 pixel 的损失为最小单元,而不是 batch 中每张图片。因此排序时需要把 batch 中所有 pixel 拉成一个长向量,再取其中大于阈值的 pixel 作为 hard example。同时,n_min 的设置保证了每个 batch 中都有至少 n_min 个pixel 参与训练,从而一定程度巩固了训练结果,让前向传播不至于空耗。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,427评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,551评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,747评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,939评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,955评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,737评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,448评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,352评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,834评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,992评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,133评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,815评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,477评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,022评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,147评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,398评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,077评论 2 355

推荐阅读更多精彩内容