pytorch半精度模型部署的具体方法是啥
Admin 2022-07-26 群英技术资讯 292 次浏览
pytorch作为深度学习的计算框架正得到越来越多的应用.
我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.
在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.
但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍
在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可
另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.
补充:pytorch 使用amp.autocast半精度加速训练
pytorch 1.6+
根据官方提供的方法,
答案就是autocast + GradScaler。
答案:autocast + GradScaler。
正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的
from torch.cuda.amp import autocast as autocast # 创建model,默认是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) for input, target in data: optimizer.zero_grad() # 前向过程(model + loss)开启 autocast with autocast(): output = model(input) loss = loss_fn(output, target) # 反向传播在autocast上下文之外 loss.backward() optimizer.step()
GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。
因此PyTorch中经典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast # 创建model,默认是torch.FloatTensor model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) # 在训练最开始之前实例化一个GradScaler对象 scaler = GradScaler() for epoch in epochs: for input, target in data: optimizer.zero_grad() # 前向过程(model + loss)开启 autocast with autocast(): output = model(input) loss = loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。
要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?
class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, input_data_c1): with autocast(): # code return
只要把forward里面的代码用autocast代码块方式运行就好啦!
如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:
1、matmul
2、addbmm
3、addmm
4、addmv
5、addr
6、baddbmm
7、bmm
8、chain_matmul
9、conv1d
10、conv2d
11、conv3d
12、conv_transpose1d
13、conv_transpose2d
14、conv_transpose3d
15、linear
16、matmul
17、mm
18、mv
19、prelu
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要为大家介绍了python中的变量,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
这篇文章主要介绍了Python机器学习三大件之一numpy,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好地帮助哟.需要的朋友可以参考下
集合(set)是一个无序的不重复元素序列。因此在每次运行的时候集合的运行结果的内容都是相同的,但元素的排列顺序却不是固定的,所以本章中部分案例的运行结果会出现与给出结果不同的情况(运行结果不唯一)可以使用大括号{}或者set()函数创建集合,注意:创建一个空集合必须用set()而不是{},因为{}是用来创建一个空字典
本文主要介绍了Django实现视频播放的具体示例,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下<BR>
这篇文章主要介绍了Python实现层次分析法及自调节层次分析法的示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008