如何理解pytorch_detach怎么实现切断网络反传
Admin 2022-06-15 群英技术资讯 379 次浏览
官方文档中,对这个方法是这么介绍的。
detach = _add_docstr(_C._TensorBase.detach, r""" Returns a new Tensor, detached from the current graph. The result will never require gradient. .. note:: Returned Tensor uses the same data tensor as the original one. In-place modifications on either of them will be seen, and may trigger errors in correctness checks. """)
返回一个新的从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch from torch.nn import init t1 = torch.tensor([1., 2.],requires_grad=True) t2 = torch.tensor([2., 3.],requires_grad=True) v3 = t1 + t2 v3_detached = v3.detach() v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值 print(v3, v3_detached) # v3 中tensor 的值也会改变 print(v3.requires_grad,v3_detached.requires_grad) ''' tensor([4., 7.], grad_fn=<AddBackward0>) tensor([4., 7.]) True False '''
在pytorch中通过拷贝需要切断位置前的tensor实现这个功能。tensor中拷贝的函数有两个,一个是clone(),另外一个是copy_(),clone()相当于完全复制了之前的tensor,他的梯度也会复制,而且在反向传播时,克隆的样本和结果是等价的,可以简单的理解为clone只是给了同一个tensor不同的代号,和‘='等价。所以如果想要生成一个新的分开的tensor,请使用copy_()。
不过对于这样的操作,pytorch中有专门的函数――detach()。
用户自己创建的节点是leaf_node(如图中的abc三个节点),不依赖于其他变量,对于leaf_node不能进行in_place操作.根节点是计算图的最终目标(如图y),通过链式法则可以计算出所有节点相对于根节点的梯度值.这一过程通过调用root.backward()就可以实现.
因此,detach所做的就是,重新声明一个变量,指向原变量的存放位置,但是requires_grad为false.更深入一点的理解是,计算图从detach过的变量这里就断了, 它变成了一个leaf_node.即使之后重新将它的requires_node置为true,它也不会具有梯度.
(0.4之后),tensor和variable合并,tensor具有grad、grad_fn等属性;
默认创建的tensor,grad默认为False, 如果当前tensor_grad为None,则不会向前传播,如果有其它支路具有grad,则只传播其它支路的grad
# 默认创建requires_grad = False的Tensor x = torch.ones(1) # create a tensor with requires_grad=False (default) print(x.requires_grad) # out: False # 创建另一个Tensor,同样requires_grad = False y = torch.ones(1) # another tensor with requires_grad=False # both inputs have requires_grad=False. so does the output z = x + y # 因为两个Tensor x,y,requires_grad=False.都无法实现自动微分, # 所以操作(operation)z=x+y后的z也是无法自动微分,requires_grad=False print(z.requires_grad) # out: False # then autograd won't track this computation. let's verify! # 因而无法autograd,程序报错 # z.backward() # out:程序报错:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn # now create a tensor with requires_grad=True w = torch.ones(1, requires_grad=True) print(w.requires_grad) # out: True # add to the previous result that has require_grad=False # 因为total的操作中输入Tensor w的requires_grad=True,因而操作可以进行反向传播和自动求导。 total = w + z # the total sum now requires grad! total.requires_grad # out: True # autograd can compute the gradients as well total.backward() print(w.grad) #out: tensor([ 1.]) # and no computation is wasted to compute gradients for x, y and z, which don't require grad # 由于z,x,y的requires_grad=False,所以并没有计算三者的梯度 z.grad == x.grad == y.grad == None # True
import torch.nn.functional as F # With square kernels and equal stride filters = torch.randn(8,4,3,3) weiths = torch.nn.Parameter(torch.randn(8,4,3,3)) inputs = torch.randn(1,4,5,5) out = F.conv2d(inputs, weiths, stride=2,padding=1) print(out.shape) con2d = torch.nn.Conv2d(4,8,3,stride=2,padding=1) out_2 = con2d(inputs) print(out_2.shape)
补充:Pytorch-detach()用法
神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整。
或者训练部分分支网络,并不让其梯度对主网络的梯度造成影响.这时候我们就需要使用detach()函数来切断一些分支的反向传播.
返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad.
即使之后重新将它的requires_grad置为true,它也不会具有梯度grad.这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播.
注意:
使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变。
比如正常的例子是:
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a) print(a.grad) out = a.sigmoid() out.sum().backward() print(a.grad)
输出
tensor([1., 2., 3.], requires_grad=True)
None
tensor([0.1966, 0.1050, 0.0452])
1.1 当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward():
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) out = a.sigmoid() print(out) #添加detach(),c的requires_grad为False c = out.detach() print(c) #这时候没有对c进行更改,所以并不会影响backward() out.sum().backward() print(a.grad) '''返回: None tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>) tensor([0.7311, 0.8808, 0.9526]) tensor([0.1966, 0.1050, 0.0452]) '''
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了Python实现迷宫生成器的详细代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
今天给大家介绍Python中的pathlib库的操作方法,pathlib 是Python内置库,pathlib库对于目录路径的操作更简洁也更贴近 Pythonic(Python代码风格的),对Python pathlib库相关知识感兴趣的朋友一起看看吧
目录题目描述示例 2:示例 3:单向构造(哈希表计数)双向构造(双指针)最后题目描述 这是 LeetCode 上的 1743. 从相邻元素对还原数组 ,难度为 中等。 Tag : 「哈希表」、「双指针」、「模拟」 存在一个由 n 个不同元素组成的整数数组 nums...
这篇文章主要介绍了通过Python实现创建语音识别控制系统,能利用语音识别识别说出来的文字,根据文字的内容来控制图形移动,感兴趣的同学可以关注一下
这篇文章主要介绍了解决pytorch rnn 变长输入序列的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008