PyTorch出现平方根报错是什么原因?怎样解决?
Admin 2021-08-21 群英技术资讯 784 次浏览
PyTorch出现平方根报错是什么原因?一些朋友遇到在使用PyTorch进行平方根计算时,出现平方根报错的情况,那么这时我们应该怎样解决呢?接下来小编就和大家一样探讨看看。
初步使用PyTorch进行平方根计算,通过range()创建一个张量,然后对其求平方根。
a = torch.tensor(list(range(9))) b = torch.sqrt(a)
报出以下错误:
RuntimeError: sqrt_vml_cpu not implemented for 'Long'
Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。
print(a.dtype)
torch.int64
提前将数据类型指定为浮点型, 重新执行:
b = torch.sqrt(a.to(torch.double)) print(b)
tensor([0.0000, 1.0000, 1.4142, 1.7321, 2.0000, 2.2361, 2.4495, 2.6458, 2.8284], dtype=torch.float64)
补充:pytorch10 pytorch常见运算详解
这个是矩阵(张量)每一个元素与标量进行操作。
import torch a = torch.tensor([1,2]) print(a+1) >>> tensor([2, 3])
这个就是两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积,也成为element wise。
a = torch.tensor([1,2]) b = torch.tensor([2,3]) print(a*b) print(torch.mul(a,b)) >>> tensor([2, 6]) >>> tensor([2, 6])
这个torch.mul()和*是等价的。
当然,除法也是类似的:
a = torch.tensor([1.,2.]) b = torch.tensor([2.,3.]) print(a/b) print(torch.div(a/b)) >>> tensor([0.5000, 0.6667]) >>> tensor([0.5000, 0.6667])
我们可以发现的torch.div()其实就是/, 类似的:torch.add就是+,torch.sub()就是-,不过符号的运算更简单常用。
如果我们想实现线性代数中的矩阵相乘怎么办呢?
这样的操作有三个写法:
torch.mm()
torch.matmul()
@,这个需要记忆,不然遇到这个可能会挺蒙蔽的
a = torch.tensor([[1.],[2.]]) b = torch.tensor([2.,3.]).view(1,2) print(torch.mm(a, b)) print(torch.matmul(a, b)) print(a @ b)
这是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul()可以使用。等等,多维张量怎么进行矩阵的乘法?在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:
a = torch.rand((1,2,64,32)) b = torch.rand((1,2,32,64)) print(torch.matmul(a, b).shape) >>> torch.Size([1, 2, 64, 64])
a = torch.rand((3,2,64,32)) b = torch.rand((1,2,32,64)) print(torch.matmul(a, b).shape) >>> torch.Size([3, 2, 64, 64])
这样也是可以相乘的,因为这里涉及一个自动传播Broadcasting机制,这个在后面会讲,这里就知道,如果这种情况下,会把b的第一维度复制3次 ,然后变成和a一样的尺寸,进行矩阵相乘。
print('幂运算') a = torch.tensor([1.,2.]) b = torch.tensor([2.,3.]) c1 = a ** b c2 = torch.pow(a, b) print(c1,c2) >>> tensor([1., 8.]) tensor([1., 8.])
和上面一样,不多说了。开方运算可以用torch.sqrt(),当然也可以用a**(0.5)。
在上学的时候,我们知道ln是以e为底的,但是在pytorch中,并不是这样。
pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。
import numpy as np print('对数运算') a = torch.tensor([2,10,np.e]) print(torch.log(a)) print(torch.log2(a)) print(torch.log10(a)) >>> tensor([0.6931, 2.3026, 1.0000]) >>> tensor([1.0000, 3.3219, 1.4427]) >>> tensor([0.3010, 1.0000, 0.4343])
.ceil() 向上取整
.floor()向下取整
.trunc()取整数
.frac()取小数
.round()四舍五入
.ceil() 向上取整.floor()向下取整.trunc()取整数.frac()取小数.round()四舍五入
a = torch.tensor(1.2345) print(a.ceil()) >>>tensor(2.) print(a.floor()) >>> tensor(1.) print(a.trunc()) >>> tensor(1.) print(a.frac()) >>> tensor(0.2345) print(a.round()) >>> tensor(1.)
这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。
a = torch.rand(5) print(a) print(a.clamp(0.3,0.7))
关于PyTorch平方根报错原因以及解决的方法就介绍到这,有需要的朋友可以参考,希望能对大家有帮助,想要了解更多PyTorch平方根报错的内容,大家还可以关注其他文章。
文本转载自脚本之家
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了Django给表单添加honeypot验证增加安全性的方法,帮助大家更好的理解和学习使用Django框架,感兴趣的朋友可以了解下
这篇文章主要介绍了如何利用python将Xmind用例转为Excel用例,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
这篇文章主要给大家分享python中else的使用,对新手学习和裂解else字句有一定的借鉴价值,感兴趣的朋友可以参考一下,希望大家阅读完这篇文章能有所收获,下面我们一起来学习一下吧。
本来我一直不知道怎么来更好地优化网页的性能,然后最近做python和php同类网页渲染速度比较时,意外地发现一个很简单很白痴但是 我一直没
本文主要介绍了简单介绍一下tensorflow与pytorch的相互转换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧<BR>
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008