2020-08-31 09:52

【Trick】Label Smoothing

Quoted from https://blog.csdn.net/sinat_36618660/article/details/100166957

why

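When training a deep learning model with one-hot labels and a cross-entropy loss, only the position of the correct label (where the one-hot label is 1) contributes to the loss; the positions of the incorrect labels (where the one-hot label is 0) contribute nothing. As a result, the model can fit the training set very well, but because the loss at the incorrect positions is never computed, it tends to become overconfident, and the probability of a wrong prediction at inference time increases. Label smoothing is a regularization method introduced to address this problem.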

how

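First, an example without label smoothing.
Example 1: Suppose we have a batch of samples with 5 classes in total. Take one sample whose one-hot label is [0,0,0,1,0], and suppose the softmax output for this sample is the probability vector p:
p=[p1,p2,p3,p4,p5]=[0.1,0.1,0.1,0.36,0.34]
The loss for this single sample is then
loss=−(0∗log0.1+0∗log0.1+0∗log0.1+1∗log0.36+0∗log0.34)
which evaluates to
loss=−log0.36≈1.02 (natural logarithm, as used by PyTorch; with log base 2 the value is 1.47)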
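The arithmetic in Example 1 can be checked with a few lines of plain Python. This is only a minimal sketch: the helper name `cross_entropy` is an illustrative assumption, and the loss is evaluated with both the natural logarithm and log base 2, since the numeric value depends on the base.

```python
import math

def cross_entropy(target, probs, log=math.log):
    """Cross-entropy -sum(y_i * log(p_i)); zero-weight terms are skipped."""
    return -sum(y * log(p) for y, p in zip(target, probs) if y > 0)

one_hot = [0, 0, 0, 1, 0]
probs = [0.1, 0.1, 0.1, 0.36, 0.34]

loss_nat = cross_entropy(one_hot, probs)               # natural logarithm
loss_bits = cross_entropy(one_hot, probs, math.log2)   # log base 2
print(round(loss_nat, 2), round(loss_bits, 2))  # 1.02 1.47
```

Only the term at the correct class survives, which is exactly the behaviour the "why" section describes.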

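Next, an example with label smoothing.
Example 2: Take the same batch with 5 classes and the same sample, whose one-hot label is [0,0,0,1,0], and again suppose the softmax output is the probability vector p:
p=[p1,p2,p3,p4,p5]=[0.1,0.1,0.1,0.36,0.34]
How do we apply label smoothing? Set a smoothing factor ϵ=0.1 and spread it uniformly over the K=5 classes, so that the smoothed label still sums to 1:
y1=(1−ϵ)∗[0,0,0,1,0]=[0,0,0,0.9,0]
y2=(ϵ/K)∗[1,1,1,1,1]=[0.02,0.02,0.02,0.02,0.02]
y=y1+y2=[0.02,0.02,0.02,0.92,0.02]
y is the smoothed label. The cross-entropy loss of the sample is then
loss=−y∗logp=−[0.02,0.02,0.02,0.92,0.02]∗log([0.1,0.1,0.1,0.36,0.34])
which evaluates to
loss≈1.10 (natural logarithm; with log base 2 the value is about 1.59)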
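The smoothing step can also be sketched in plain Python. The helper names `smooth_labels` and `cross_entropy` are assumptions made for illustration, and the sketch uses the ϵ/K uniform mixing that the PyTorch code in this post implements, so the smoothed label sums to 1. The loss uses the natural logarithm.

```python
import math

def smooth_labels(one_hot, epsilon):
    """Mix a one-hot vector with the uniform distribution over K classes."""
    k = len(one_hot)
    return [(1 - epsilon) * y + epsilon / k for y in one_hot]

def cross_entropy(target, probs):
    """Cross-entropy -sum(y_i * ln(p_i)) using the natural logarithm."""
    return -sum(y * math.log(p) for y, p in zip(target, probs))

probs = [0.1, 0.1, 0.1, 0.36, 0.34]
y = smooth_labels([0, 0, 0, 1, 0], epsilon=0.1)
loss = cross_entropy(y, probs)

print([round(v, 2) for v in y])  # [0.02, 0.02, 0.02, 0.92, 0.02]
print(round(loss, 2))            # 1.1
```

Unlike the one-hot case, every class now contributes a term, so the loss also penalizes very small probabilities on the incorrect classes.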

code

import torch
import torch.nn.functional as F


def cross_entropy_loss(preds, target, reduction):
    # Cross-entropy between logits `preds` and soft (probability) targets.
    logp = F.log_softmax(preds, dim=1)
    loss = torch.sum(-logp * target, dim=1)
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise ValueError(
            "`reduction` must be one of 'none', 'mean', or 'sum'.")


def onehot_encoding(labels, n_classes):
    # labels: integer (int64) class indices of shape (N,)
    return torch.zeros(labels.size(0), n_classes, device=labels.device).scatter_(
        dim=1, index=labels.view(-1, 1), value=1)


def label_smoothing(preds, targets, epsilon=0.1):
    # preds: logits from the network's last layer, shape (N, n_classes)
    # targets: ground-truth labels as class indices (not one-hot), shape (N,)
    n_classes = preds.size(1)
    device = preds.device

    onehot = onehot_encoding(targets, n_classes).float().to(device)
    # (1 - epsilon) on the true class plus epsilon / n_classes on every class
    targets = onehot * (1 - epsilon) + torch.ones_like(onehot) * epsilon / n_classes
    loss = cross_entropy_loss(preds, targets, reduction="mean")
    return loss
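As a sanity check, the same smoothing formula can be compared against PyTorch's built-in option. This is a self-contained sketch under stated assumptions: it assumes PyTorch 1.10 or later, where `F.cross_entropy` accepts a `label_smoothing` argument, and the random logits, labels, and variable names are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)            # batch of 4 samples, 5 classes
labels = torch.tensor([3, 0, 1, 4])   # integer class indices
epsilon = 0.1

# Manual smoothing: (1 - epsilon) * one_hot + epsilon / K
one_hot = F.one_hot(labels, num_classes=5).float()
soft = one_hot * (1 - epsilon) + epsilon / 5
manual = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Built-in equivalent (assumes PyTorch >= 1.10)
builtin = F.cross_entropy(logits, labels, label_smoothing=epsilon)

print(torch.allclose(manual, builtin, atol=1e-6))
```

The two values should agree up to floating-point error, so on recent PyTorch versions the built-in argument is a reasonable alternative to a hand-written loss.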
