文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
0. 测试环境
Python 3.6.9, Pytorch 1.5.0
1. 基本概念
Tensor是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32。
- 示例一
1 | >>> a = torch.tensor([1.0]) |
Tensor中只有一个数字时,使用torch.Tensor.item()可以得到一个Python数字。requires_grad为True时,表示需要计算Tensor的梯度。requires_grad=False可以用来冻结部分网络,只更新另一部分网络的参数。
- 示例二
1 | >>> a = torch.tensor([1.0, 2.0]) |
a.data返回的是一个新的Tensor对象b,a, b的id不同,说明二者不是同一个Tensor,但b与a共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b的元素时,a的元素也对应修改。
2. requiresgrad()与detach()
1 | >>> a = torch.tensor([1.0, 2.0]) |
requires_grad_()
requires_grad_()函数会改变Tensor的requires_grad属性并返回Tensor,修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=True。requires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。
detach()
detach()函数会返回一个新的Tensor对象b,并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。b与a共享数据的存储空间,二者指向同一块内存。
注:共享内存空间只是共享的数据部分,a.grad与b.grad是不同的。
3. torch.no_grad()
torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
1 | a = torch.tensor([1.0, 2.0], requires_grad=True) |
上面的例子中,当a的requires_grad=True时,不使用torch.no_grad(),c.requires_grad为True,使用torch.no_grad()时,b.requires_grad为False,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True会占用更多的计算资源及存储资源。
4. 总结
requires_grad_()会修改Tensor的requires_grad属性。
detach()会返回一个与计算图分离的新Tensor,新Tensor不会在反向传播中计算梯度,会在特定场合使用。
torch.no_grad()更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。