KL散度实现计算过程的相关问题怎样解决
Admin 2022-07-26 群英技术资讯 416 次浏览
偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下
一篇关于KL散度、JS散度以及交叉熵对比的文章
KL散度( Kullback-CLeibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:
可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。
两个离散分布分布分别为 P 和 Q
P 的分布为:{1,1,2,2,3}
Q 的分布为:{1,1,1,1,1,2,3,3,3,3}
我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。
当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5
同理,
当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
至此,就计算完成了两个离散变量分布的KL散度。
pytorch中有用于计算kl散度的函数 kl_div
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
计算 D (p||q)
与手算结果相同
(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)
注意:
1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和
好别扭的用法,不知道为啥官方把它设计成这样
补充:pytorch 的KL divergence的实现
import torch.nn.functional as F # p_logit: [batch, class_num] # q_logit: [batch, class_num] def kl_categorical(p_logit, q_logit): p = F.softmax(p_logit, dim=-1) _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), 1) return torch.mean(_kl)
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
本篇文章给大家带来了关于Python的相关知识,其中主要介绍了垃圾回收机制中的引用计数的相关内容,如果我们在Python中有一个指向某个对象的指针,那就是对该对象的引用,下面一起来看一下,希望对大家有帮助。
这篇文章主要介绍了Python网络编程之ZeroMQ知识总结,文中有非常详细的代码示例,对正在学习python的小伙伴们有非常好的帮助,需要的朋友可以参考下
这篇文章主要介绍了Python 制作子弹图,众所周知,Python 的应用是非常广泛的,今天我们就通过 matplotlib 库学习下如何制作精美的子弹图,需要的朋友可以参考一下
python@运算符是什么意思?怎样使用?对于刚接触Python的朋友,可能对@运算符不是很了解,因此这篇文章就给大家介绍一下python@运算符的内容,感兴趣的朋友就继续往下看吧。
这篇文章主要为大家介绍了python基础中的文件对象,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助<BR>
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008