BCE loss和 CE理解

BCE loss和 CE理解BCElosspytorch官网链接BCEloss:BinaryCrossEntropyLosspytorch中调用如下。设置weight,使得不同类别的损失权值不同。其中x是预测值,取值范围(0,1),target是标签,取值为0或1.在Retinanet的分类部分最后一层的激活函数用的是sigmoid,损失函数是BCEloss.BCEloss可以对单个类别进行求损失,配合sigmoid(每个类别单独求概率,不同类别间互不影响)。相当于是把每个类别都看成了二分类的问题,为后面使

大家好,欢迎来到IT知识分享网。

1. BCE loss:Binary Cross Entropy Loss

BCE loss pytorch官网链接

1.1 解释

pytorch中调用如下。设置weight,使得不同类别的损失权值不同。
在这里插入图片描述
其中x是预测值,取值范围(0,1), target是标签,取值为0或1.

在Retinanet的分类部分最后一层的激活函数用的是sigmoid,损失函数是BCE loss.BCE loss可以对单个类别进行求损失,配合sigmoid(每个类别单独求概率,不同类别间互不影响)。相当于是把每个类别都看成了二分类的问题,为后面使用Focal Loss做准备。
Focal Loss是针对目标检测任务中正负样本不均横提出的,通过使用Focal Loss,使得易于分类的背景类别的损失相对于不适用会大幅降低。

1.2 实现代码

import torch
from torch.nn import BCELoss
import torch.nn as nn

y = torch.randn((2,3))
target = torch.tensor([[0,0,1],[1,0,0]],dtype=torch.float32)

m = nn.Sigmoid()# BEC的输入变量的取值范围[0,1]
y = m(y)
print('y',y)
print('target',target)

loss_func = BCELoss()
loss = loss_func(y,target)

def compute_bce(y,target):
    b = y.shape[0]
    loss = 0
    for i in range(b): # 多少个实例
        for j in range(len(y[i])):# 每个实例有多少个类别
            if target[i][j] == 0.:
                temp_loss =  torch.log(1-y[i][j])
            else:
                temp_loss = torch.log(y[i][j])
            loss -= temp_loss
    return loss/(b*y.shape[1])
loss_ = compute_bce(y,target)

print('loss',loss)
print('loss',loss_)

输出

y tensor([[0.3410, 0.1810, 0.1073],
        [0.7485, 0.6879, 0.2198]])
target tensor([[0., 0., 1.],
        [1., 0., 0.]])
loss tensor(0.7586)
loss compute tensor(0.7586)

1.3 求解过程

BCE Loss默认求误差的方式是取平均
预测值y,标签是target

y tensor([[0.3410, 0.1810, 0.1073],
        [0.7485, 0.6879, 0.2198]])
target tensor([[0., 0., 1.],
        [1., 0., 0.]])

y的形式可以看成是三分类,共有两个样本
根据公式 b c e _ l o s s ( y , t a q g e t ) = − 1 N Σ i N [ t a r g e t [ i ] ∗ l o g ( y [ i ] ) + ( 1 − t a r g e t [ i ] ) ∗ l o g ( 1 − y [ i ] ) ] bce\_loss(y,taqget) = -\frac{1}{N}\Sigma_{i}^{N}[target[i]*log(y[i]) + (1-target[i])*log(1-y[i])] bce_loss(y,taqget)=N1ΣiN[target[i]log(y[i])+(1target[i])log(1y[i])]
这个是求均值的方式。
[0.3410, 0.1810, 0.1073]对应的标签为[0., 0., 1.]
对于第一个实例,它的损失为
− 1 3 [ [ 0 ∗ l o g ( 0.3410 ) + ( 1 − 0 ) ∗ l o g ( 1 − 0.3410 ) ] + [ 0 ∗ l o g ( 0.1810 ) + ( 1 − 0 ) ∗ l o g ( 1 − 0.1810 ) ] + [ 1 ∗ l o g ( 0.1073 ) + ( 1 − 1 ) ∗ l o g ( 1 − 0.1810 ) ] ] -\frac{1}{3}[[0*log(0.3410)+(1-0)*log(1-0.3410)] + [0*log(0.1810)+(1-0)*log(1-0.1810)] + [1*log(0.1073)+(1-1)*log(1-0.1810)] ] 31[[0log(0.3410)+(10)log(10.3410)]+[0log(0.1810)+(10)log(10.1810)]+[1log(0.1073)+(11)log(10.1810)]],
三个类别单独求损失,然后加起来取平均

BCE Loss 类似的是BCEWithLogiticsLoss,不同之处在于对预测值加了sigmoid变化
在这里插入图片描述

2. CE Loss

CE: Cross Entropy

2.1 解释

pytorch 官网介绍
设置weight,使得不同类别的损失权值不同。
输入变量:预测值inputtarget.
处理过程:先对input求softmax,然后对每个概率取对数,再求NLLLoss(negative log likelihood loss,负对数似然损失)

在这里插入图片描述

2.2 代码实现

import torch
import torch.nn as nn
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
print('input',input)
print('target',target)
print('output loss',output)

def compute_ce(input,target):
    input = nn.Softmax(dim=-1)(input)
    b = input.shape[0]
    loss = 0
    for i in range(b):
        # input[i] 第i个样本对应的预测概率
        # target[i]样本标签
        # 取出input[i]中索引为target[i]的概率值
        
        temp_loss = -1*torch.log(input[i][target[i]])
        loss += temp_loss
    loss = loss/b
    return loss
loss_compute = compute_ce(input,target)
print('compute_loss',loss_compute)


输出

input tensor([[ 5.1491e-01,  4.8284e-01, -5.4881e-01,  3.3695e-01, -9.2449e-01],
        [ 7.5421e-01, -1.3219e+00,  8.3139e-02, -9.4772e-02,  9.4742e-01],
        [-3.5083e-01,  1.1744e+00, -1.0697e+00,  1.9165e-03, -2.4743e+00]],
       requires_grad=True)
target tensor([2, 4, 4])
output tensor(2.4776, grad_fn=<NllLossBackward>)
compute loss tensor(2.4776, grad_fn=<DivBackward0>)

2.3 过程

  1. 先对input取softmax
  2. 遍历单个样本,比如第一个的损失为-log(input[0][2]),第二个-log(input[1][4]),第三个-log(input[2][4])
  3. 对当个样本的损失求和,取平均

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/12623.html

(0)

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

关注微信