pytorch固定BN层参数怎么操作,相关问题如何解决
Admin 2022-07-29 群英技术资讯 311 次浏览
基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。
未固定主分支BN层中的running_mean和running_var。
将需要固定的BN层状态设置为eval。
环境:torch:1.7.0
# -*- coding:utf-8 -*- import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 6, 3) self.bn1 = nn.BatchNorm2d(6) self.conv2 = nn.Conv2d(6, 16, 3) self.bn2 = nn.BatchNorm2d(16) # an affine operation: y = Wx + b self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 5) def forward(self, x): # Max pooling over a (2, 2) window x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2)) # If the size is a square you can only specify a single number x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2) x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] # all dimensions except the batch dimension num_features = 1 for s in size: num_features *= s return num_features def print_parameter_grad_info(net): print('-------parameters requires grad info--------') for name, p in net.named_parameters(): print(f'{name}:\t{p.requires_grad}') def print_net_state_dict(net): for key, v in net.state_dict().items(): print(f'{key}') if __name__ == "__main__": net = Net() print_parameter_grad_info(net) net.requires_grad_(False) print_parameter_grad_info(net) torch.random.manual_seed(5) test_data = torch.rand(1, 1, 32, 32) train_data = torch.rand(5, 1, 32, 32) # print(test_data) # print(train_data[0, ...]) for epoch in range(2): # training phase, 假设每个epoch只迭代一次 net.train() pre = net(train_data) # 计算损失和参数更新等 # .... # test phase net.eval() x = net(test_data) print(f'epoch:{epoch}', x)
运行结果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。
调用print_net_state_dict可以看到BN层中的参数running_mean和running_var并没在可优化参数net.parameters中
bn1.weight bn1.bias bn1.running_mean bn1.running_var bn1.num_batches_tracked
但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source
因此在training phase时对BN层显式设置eval状态:
if __name__ == "__main__": net = Net() net.requires_grad_(False) torch.random.manual_seed(5) test_data = torch.rand(1, 1, 32, 32) train_data = torch.rand(5, 1, 32, 32) # print(test_data) # print(train_data[0, ...]) for epoch in range(2): # training phase, 假设每个epoch只迭代一次 net.train() net.bn1.eval() net.bn2.eval() pre = net(train_data) # 计算损失和参数更新等 # .... # test phase net.eval() x = net(test_data) print(f'epoch:{epoch}', x)
可以看到结果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
补充:pytorch---之BN层参数详解及应用(1,2,3)(1,2)?
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对于BN层测试的均值和方差是通过统计训练的时候所有的batch的均值和方差的平均值)或者Dropout层(对于Dropout层在测试的时候所有神经元都是激活的)。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。
其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。(这里是一个可学习参数)
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性(意思就是说新的batch依赖于之前的batch的均值和方差这里使用momentum参数,参考了指数移动平均的算法EMA)。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
通常pytorch都会用到optimizer.zero_grad() 来清空以前的batch所累加的梯度,因为pytorch中Variable计算的梯度会进行累计,所以每一个batch都要重新清空一次梯度,原始的做法是下面这样的:
问题:参数non_blocking,以及pytorch的整体框架??
代码(1)
for index,data,target in enumerate(dataloader): data = data.cuda(non_blocking=True) target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye) output = model(data) loss = criterion(output,target) #清空梯度 optimizer.zero_grad() loss.backward() optimizer.step()
而这里为了模仿minibacth,我们每次batch不清0,累积到一定次数再清0,再更新权重:
for index, data, target in enumerate(dataloader): #如果不是Tensor,一般要用到torch.from_numpy() data = data.cuda(non_blocking = True) target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True) output = model(data) loss = criterion(data, target) loss.backward() if index%accumulation == 0: #用累积的梯度更新权重 optimizer.step() #清空梯度 optimizer.zero_grad()
虽然这里的梯度是相当于原来的accumulation倍,但是实际在前向传播的过程中,对于BN几乎没有影响,因为前向的BN还是只是一个batch的均值和方差,这个时候可以用pytorch中BN的momentum参数,默认是0.1,BN参数如下,就是指数移动平均
x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了python倒序for循环实例,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
这篇文章主要为大家介绍了python人工智能tensorflow构建循环神经网络RNN,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
python集合怎样理解?集合是python中很基础的一个内容,这篇文章主要给大家分享的是集合的基本信息和集合的基本操作,有这方面学习需要的朋友可以参考。
文本主要给给大家分享的是关于python匿名函数的内容,匿名函数也就是没有名字的函数,在python中还是比较实用的,因此分享给大家作参考,下面我们就一起来学习一下python匿名函数吧。
random模块在python中起到的是生成随机数的作用,random模块中choice()可以从序列中获取一个随机元素,并返回一个(列表,元组或字符串中的)随机项。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008