BatchNormalization底层原理如何理解,Bn层的好处是什么
Admin 2022-09-03 群英技术资讯 298 次浏览
Batch Normalization是神经网络中常用的层,解决了很多深度学习中遇到的问题,我们一起来学习一哈。
Batch Normalization是由google提出的一种训练优化方法。参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift。
Batch Normalization的名称为批标准化,它的功能是使得输入的X数据符合同一分布,从而使得训练更加简单、快速。
一般来讲,Batch Normalization会放在卷积层后面,即卷积 + 标准化 + 激活函数。
其计算过程可以简单归纳为以下3点:
1、求数据均值。
2、求数据方差。
3、数据进行标准化。
Batch Normalization的计算公式主要看如下这幅图:
这个公式一定要静下心来看,整个公式可以分为四行:
1、对输入进来的数据X进行均值求取。
2、利用输入进来的数据X减去第一步得到的均值,然后求平方和,获得输入X的方差。
3、利用输入X、第一步获得的均值和第二步获得的方差对数据进行归一化,即利用X减去均值,然后除上方差开根号。方差开根号前需要添加上一个极小值。
4、引入γ和β变量,对输入进来的数据进行缩放和平移。利用γ和β两个参数,让我们的网络可以学习恢复出原始网络所要学习的特征分布。
前三步是标准化工序,最后一步是反标准化工序。
1、加速网络的收敛速度。在神经网络中,存在内部协变量偏移的现象,如果每层的数据分布不同的话,会导致非常难收敛,如果把每层的数据都在转换在均值为零,方差为1的状态下,这样每层数据的分布都是一样的,训练会比较容易收敛。
2、防止梯度爆炸和梯度消失。对于梯度消失而言,以Sigmoid函数为例,它会使得输出在[0,1]之间,实际上当x到了一定的大小,sigmoid激活函数的梯度值就变得非常小,不易训练。归一化数据的话,就能让梯度维持在比较大的值和变化率;
对于梯度爆炸而言,在方向传播的过程中,每一层的梯度都是由上一层的梯度乘以本层的数据得到。如果归一化的话,数据均值都在0附近,很显然,每一层的梯度不会产生爆炸的情况。
3、防止过拟合。在网络的训练中,Bn使得一个minibatch中所有样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合。
Bn层在进行前三步后,会引入γ和β变量,对输入进来的数据进行缩放和平移。
γ和β变量是网络参数,是可学习的。
引入γ和β变量进行缩放平移可以使得神经网络有自适应的能力,在标准化效果好时,尽量不抵消标准化的作用,而在标准化效果不好时,尽量去抵消一部分标准化的效果,相当于让神经网络学会要不要标准化,如何折中选择。
Pytorch代码看起来比较简单,而且和上面的公式非常符合,可以学习一下,参考自
https://www.jb51.net/article/247197.htm
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9): if not is_training: x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps) else: mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) x_hat = (x - mean) / torch.sqrt(var + eps) moving_mean = momentum * moving_mean + (1.0 - momentum) * mean moving_var = momentum * moving_var + (1.0 - momentum) * var Y = gamma * x_hat + beta return Y, moving_mean, moving_var class BatchNorm2d(nn.Module): def __init__(self, num_features): super(BatchNorm2d, self).__init__() shape = (1, num_features, 1, 1) self.gamma = nn.Parameter(torch.ones(shape)) self.beta = nn.Parameter(torch.zeros(shape)) self.register_buffer('moving_mean', torch.zeros(shape)) self.register_buffer('moving_var', torch.ones(shape)) def forward(self, x): if self.moving_mean.device != x.device: self.moving_mean = self.moving_mean.to(x.device) self.moving_var = self.moving_var.to(x.device) y, self.moving_mean, self.moving_var = batch_norm(self.training, x, self.gamma, self.beta, self.moving_mean, self.moving_var, eps=1e-5, momentum=0.9) return y
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇博客将介绍Canny边缘检测的概念,并利用cv2.Canny()实现边缘检测,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
关于python中迭代器,生成器介绍的文章不算少数,有些写的也很透彻,但是更多的是碎片化的内容。本篇尝试用系统的介绍三者的概念和关系,需要的可以参考一下
这篇文章主要介绍了Python学习之名字,作用域,名字空间,文章围绕主题展开详细内容介绍,具有一定的参考价值,需要的小伙伴可以参考以一下
用Python怎样将日历与时间转换,方法和代码是什么?有不少朋友对此感兴趣,下面小编给大家整理和分享了相关知识和资料,易于大家学习和理解,有需要的朋友可以借鉴参考,下面我们一起来了解一下吧。
目录前言:实例1实例2前言:字符画:字符画是一系列字符的组合,我们可以把字符看作是比较大块的像素,一个字符能表现一种颜色,字符的种类越多,可以表现的颜色也越多,图片也会更有层次感
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008