pytorch固定BN层参数怎么操作,相关问题如何解决
Admin 2022-07-29 群英技术资讯 454 次浏览
基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。
未固定主分支BN层中的running_mean和running_var。
将需要固定的BN层状态设置为eval。
环境:torch:1.7.0
# -*- coding:utf-8 -*- import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 6, 3) self.bn1 = nn.BatchNorm2d(6) self.conv2 = nn.Conv2d(6, 16, 3) self.bn2 = nn.BatchNorm2d(16) # an affine operation: y = Wx + b self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 5) def forward(self, x): # Max pooling over a (2, 2) window x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2)) # If the size is a square you can only specify a single number x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2) x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] # all dimensions except the batch dimension num_features = 1 for s in size: num_features *= s return num_features def print_parameter_grad_info(net): print('-------parameters requires grad info--------') for name, p in net.named_parameters(): print(f'{name}:\t{p.requires_grad}') def print_net_state_dict(net): for key, v in net.state_dict().items(): print(f'{key}') if __name__ == "__main__": net = Net() print_parameter_grad_info(net) net.requires_grad_(False) print_parameter_grad_info(net) torch.random.manual_seed(5) test_data = torch.rand(1, 1, 32, 32) train_data = torch.rand(5, 1, 32, 32) # print(test_data) # print(train_data[0, ...]) for epoch in range(2): # training phase, 假设每个epoch只迭代一次 net.train() pre = net(train_data) # 计算损失和参数更新等 # .... # test phase net.eval() x = net(test_data) print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。
调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中
bn1.weight bn1.bias bn1.running_mean bn1.running_var bn1.num_batches_tracked
但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source
因此在training phase时对BN层显式设置eval状态:
if __name__ == "__main__": net = Net() net.requires_grad_(False) torch.random.manual_seed(5) test_data = torch.rand(1, 1, 32, 32) train_data = torch.rand(5, 1, 32, 32) # print(test_data) # print(train_data[0, ...]) for epoch in range(2): # training phase, 假设每个epoch只迭代一次 net.train() net.bn1.eval() net.bn2.eval() pre = net(train_data) # 计算损失和参数更新等 # .... # test phase net.eval() x = net(test_data) print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
补充:pytorch---之BN层参数详解及应用(1,2,3)(1,2)?
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对于BN层测试的均值和方差是通过统计训练的时候所有的batch的均值和方差的平均值)或者Dropout层(对于Dropout层在测试的时候所有神经元都是激活的)。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。
其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。(这里是一个可学习参数)
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性(意思就是说新的batch依赖于之前的batch的均值和方差这里使用momentum参数,参考了指数移动平均的算法EMA)。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
通常pytorch都会用到optimizer.zero_grad() 来清空以前的batch所累加的梯度,因为pytorch中Variable计算的梯度会进行累计,所以每一个batch都要重新清空一次梯度,原始的做法是下面这样的:
问题:参数non_blocking,以及pytorch的整体框架??
代码(1)
for index,data,target in enumerate(dataloader): data = data.cuda(non_blocking=True) target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye) output = model(data) loss = criterion(output,target) #清空梯度 optimizer.zero_grad() loss.backward() optimizer.step()
而这里为了模仿minibacth,我们每次batch不清0,累积到一定次数再清0,再更新权重:
for index, data, target in enumerate(dataloader): #如果不是Tensor,一般要用到torch.from_numpy() data = data.cuda(non_blocking = True) target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True) output = model(data) loss = criterion(data, target) loss.backward() if index%accumulation == 0: #用累积的梯度更新权重 optimizer.step() #清空梯度 optimizer.zero_grad()
虽然这里的梯度是相当于原来的accumulation倍,但是实际在前向传播的过程中,对于BN几乎没有影响,因为前向的BN还是只是一个batch的均值和方差,这个时候可以用pytorch中BN的momentum参数,默认是0.1,BN参数如下,就是指数移动平均
x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
Shell 是一个用 C 语言编写的程序,它是用户使用 Linux 的桥梁。Shell 既是一种命令语言,又是一种程序设计语言。
这篇文章主要为大家介绍了python编程中Flask框架简单使用教程,有需要的朋友可以借鉴参考下希望能够有所帮助,祝大家多多进步早日升职加薪
这篇文章主要为大家介绍了python密码学Base64编码和解码教程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
我们需要使用到图片素材的场景很多,但是很多素材都有水印,而一张张去除水印是工作量大。对此,这篇文章小编就给大家分享如何用python实现图片批量去水印的方法,下面我们一起来看看是怎样做的吧。
这篇文章主要介绍了Python和java 如何相互调用,下面文章见到那的对Python和java 相互调用的方法做了个小总结,具有一定的参考价值,需要的小伙伴可以参考一下,希望对你有所帮助
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008