BCEWithLogitsLoss参数pos_weight样本不均衡问题

下面是具体的参数:

1. pos_weight:
  • 处理样本不均衡问题
    torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)


  • 其中* pos_weight (Tensor*, *optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
  • pos_weight里是一个tensor列表,需要和标签个数相同,比如现在有一个多标签分类,类别有200个,那么 pos_weight 就是为每个类别赋予的权重值,长度为200,官方给出的例子是:
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # A prediction (logit)
pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5))
  • 如果现在是二分类,只需要将正样本loss的权重写上即可,比如我们有正负两类样本,正样本数量为100个,负样本为400个,我们想要对正负样本的loss进行加权处理,将正样本的loss权重放大4倍,通过这样的方式缓解样本不均衡问题:
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))

-- pos_weight (Tensor, optional): a weight of positive examples.
--Must be a vector with length equal to the number of classes.

参考:
BCEWithLogitsLoss样本不均衡的处理

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容