The specific parameters are as follows:
1. pos_weight:
- Purpose: handling class (sample) imbalance
torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
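- Here, pos_weight (Tensor, optional): a weight of positive examples. Must be a vector with length equal to the number of classes.
- pos_weight is a tensor whose length must equal the number of labels (classes). For example, in a multi-label classification task with 200 classes, pos_weight assigns one weight to each class and therefore has length 200. The official example is: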
import torch
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5) # A prediction (logit)
pos_weight = torch.ones([64]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5))
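To make the effect of pos_weight concrete, the sketch below recomputes the weighted loss by hand using the per-element formula from the PyTorch documentation, loss = -[p_c * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], averaged over all elements (reduction='mean'), and checks it against the built-in loss. It also shows one common heuristic (an assumption, not the only choice) for building a per-class pos_weight in a multi-label task: num_negatives / num_positives for each column of the label matrix. The helper names `per_class_pos_weight` and `manual_bce_with_logits` and the toy labels are hypothetical.

```python
import torch
import torch.nn.functional as F


def per_class_pos_weight(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: num_negatives / num_positives for each class (column).

    labels: multi-hot float tensor of shape [num_samples, num_classes].
    clamp(min=1) avoids division by zero for classes with no positives.
    """
    num_pos = labels.sum(dim=0)
    num_neg = labels.shape[0] - num_pos
    return num_neg / num_pos.clamp(min=1)


def manual_bce_with_logits(logits, target, pos_weight):
    """Hypothetical helper: recompute BCEWithLogitsLoss(pos_weight=...) with reduction='mean'."""
    log_p = F.logsigmoid(logits)       # log(sigmoid(x))
    log_not_p = F.logsigmoid(-logits)  # log(1 - sigmoid(x))
    loss = -(pos_weight * target * log_p + (1 - target) * log_not_p)
    return loss.mean()


# Toy multi-label data: 6 samples, 3 classes
labels = torch.tensor([
    [1., 0., 0.],
    [1., 0., 1.],
    [0., 0., 0.],
    [1., 1., 0.],
    [0., 0., 0.],
    [1., 0., 0.],
])
pos_weight = per_class_pos_weight(labels)
# class 0: 4 pos / 2 neg -> 0.5; class 1: 1 pos / 5 neg -> 5.0; class 2: 1 pos / 5 neg -> 5.0
print(pos_weight)  # tensor([0.5000, 5.0000, 5.0000])

torch.manual_seed(0)
logits = torch.randn(6, 3)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # broadcast over the batch
builtin = criterion(logits, labels)
manual = manual_bce_with_logits(logits, labels, pos_weight)
print(torch.allclose(builtin, manual))  # True
```

Because pos_weight only multiplies the y = 1 term, a weight above 1 makes missed positives of that class more expensive, and a weight below 1 (class 0 here, where positives are the majority) does the opposite.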
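- For binary classification, you only need to supply the weight for the positive-class loss. For example, suppose there are 100 positive and 400 negative samples and we want to weight their losses: setting the positive-sample loss weight to 4 (400 / 100) amplifies that loss 4 times, which mitigates the class imbalance:
import torch
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]))  # float weight, matching the float targets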
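A minimal sketch of the binary case, assuming the 100-positive / 400-negative split above: it derives pos_weight as num_negatives / num_positives and checks that, for a positive sample, the weighted loss is 4 times the unweighted one, while the loss of a negative sample is unchanged. The logit value 0.3 is arbitrary and chosen only for illustration.

```python
import torch

num_pos, num_neg = 100, 400
pos_weight = torch.tensor([num_neg / num_pos])  # tensor([4.]), float32

weighted = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
plain = torch.nn.BCEWithLogitsLoss()

# Binary task: one logit per sample, shape [batch, 1]
logit = torch.tensor([[0.3]])
positive = torch.tensor([[1.0]])
negative = torch.tensor([[0.0]])

# Positive sample: the loss is scaled by pos_weight (ratio ≈ 4)
ratio_pos = weighted(logit, positive) / plain(logit, positive)
# Negative sample: pos_weight has no effect (ratio ≈ 1)
ratio_neg = weighted(logit, negative) / plain(logit, negative)
print(ratio_pos, ratio_neg)
```

The negatives-to-positives ratio is a common starting point rather than a fixed rule; the weight can be tuned like any other hyperparameter.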