pytorch二分类交叉熵是什么,怎样实现
Admin 2022-07-26 群英技术资讯 340 次浏览
通常,由于类别不均衡,需要使用weighted cross entropy loss平衡。
def inverse_freq(label): """ 输入label [N,1,H,W],1是channel数目 """ den = label.sum() # 0 _,_,h,w= label.shape num = h*w alpha = den/num # 0 return torch.tensor([alpha, 1-alpha]).cuda() # train ... loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))
补充:Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)
在Pytorch中的交叉熵函数的血泪史要从nn.CrossEntropyLoss()这个损失函数开始讲起。
从表面意义上看,这个函数好像是普通的交叉熵函数,但是如果你看过一些Pytorch的资料,会告诉你这个函数其实是softmax()和交叉熵的结合体。
然而如果去官方看这个函数的定义你会发现是这样子的:
哇,竟然是nn.LogSoftmax()和nn.NLLLoss()的结合体,这俩都是什么玩意儿啊。再看看你会发现甚至还有一个损失叫nn.Softmax()以及一个叫nn.nn.BCELoss()。
我们来探究下这几个损失到底有何种关系。
首先nn.Softmax()官网的定义是这样的:
嗯...就是我们认识的那个softmax。那nn.LogSoftmax()的定义也很直观了:
果不其然就是Softmax取了个log。可以写个代码测试一下:
import torch import torch.nn as nn a = torch.Tensor([1,2,3]) #定义Softmax softmax = nn.Softmax() sm_a = softmax=nn.Softmax() print(sm) #输出:tensor([0.0900, 0.2447, 0.6652]) #定义LogSoftmax logsoftmax = nn.LogSoftmax() lsm_a = logsoftmax(a) print(lsm_a) #输出tensor([-2.4076, -1.4076, -0.4076]),其中ln(0.0900)=-2.4076
上面说过nn.CrossEntropy()是nn.LogSoftmax()和nn.NLLLoss的结合,nn.NLLLoss官网给的定义是这样的:
The negative log likelihood loss. It is useful to train a classification problem with C classes
负对数似然损失 ,看起来好像有点晦涩难懂,写个代码测试一下:
import torch import torch.nn a = torch.Tensor([[1,2,3]]) nll = nn.NLLLoss() target1 = torch.Tensor([0]).long() target2 = torch.Tensor([1]).long() target3 = torch.Tensor([2]).long() #测试 n1 = nll(a,target1) #输出:tensor(-1.) n2 = nll(a,target2) #输出:tensor(-2.) n3 = nll(a,target3) #输出:tensor(-3.)
看起来nn.NLLLoss做的事情是取出a中对应target位置的值并取负号,比如target1=0,就取a中index=0位置上的值再取负号为-1,那这样做有什么意义呢,要结合nn.CrossEntropy往下看。
看下官网给的nn.CrossEntropy()的表达式:
看起来应该是softmax之后取了个对数,写个简单代码测试一下:
import torch import torch.nn as nn a = torch.Tensor([[1,2,3]]) target = torch.Tensor([2]).long() logsoftmax = nn.LogSoftmax() ce = nn.CrossEntropyLoss() nll = nn.NLLLoss() #测试CrossEntropyLoss cel = ce(a,target) print(cel) #输出:tensor(0.4076) #测试LogSoftmax+NLLLoss lsm_a = logsoftmax(a) nll_lsm_a = nll(lsm_a,target) #输出tensor(0.4076)
看来直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一样的结果。为什么这样呢,回想下交叉熵的表达式:
其中y是label,x是prediction的结果,所以其实交叉熵损失就是负的target对应位置的输出结果x再取-log。这个计算过程刚好就是先LogSoftmax()再NLLLoss()。
所以我认为nn.CrossEntropyLoss其实应该叫做softmaxloss更为合理一些,这样就不会误解了。
你以为这就完了吗,其实并没有。还有一类损失叫做BCELoss,写全了的话就是Binary Cross Entropy Loss,就是交叉熵应用于二分类时候的特殊形式,一般都和sigmoid一起用,表达式就是二分类交叉熵:
直觉上和多酚类交叉熵的区别在于,不仅考虑了的样本,也考虑了的样本的损失。
nn.LogSoftmax是在softmax的基础上取自然对数
nn.NLLLoss是负的似然对数损失,但Pytorch的实现就是把对应target上的数取出来再加个负号,要在CrossEntropy中结合LogSoftmax来用
BCELoss是二分类的交叉熵损失,Pytorch实现中和多分类有区别
Pytorch是个深坑,让我们一起扎根使用手册,结合实践踏平这些坑吧暴风哭泣。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章介绍了python自动化测试之破解图文验证码的解决方案,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
selenium的几种定位方法中,大杀器之一就是xpath方法,学会它,你将无所不能.本文就带大家详细了解一下这个大杀器,文中有非常详细的介绍,需要的朋友可以参考下
Python如何实现读取远程页面并写入本地页面,废话不多说,直接看代码
这篇文章主要介绍了如何生成对角矩阵 numpy.diag,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
Python提供了两个级别的访问网络服务。 在低级别,可以访问底层操作系统中的基本套接字支持,这允许您实现面向连接和无连接协议的客户端和服务器。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008