在pytorch中对非叶节点的变量计算梯度实例

 更新时间:2020-01-11 00:00:25   作者:佚名   我要评论(0)

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。

注册hook函数

Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上这篇在pytorch中对非叶节点的变量计算梯度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:

  • pytorch的梯度计算以及backward方法详解
  • PyTorch中的Variable变量详解

相关文章

  • 在pytorch中对非叶节点的变量计算梯度实例

    在pytorch中对非叶节点的变量计算梯度实例

    在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点
    2020-01-11
  • 详解Linux环境变量配置全攻略

    详解Linux环境变量配置全攻略

    在自定义安装软件的时候,经常需要配置环境变量,下面列举出各种对环境变量的配置方法。 下面所有例子的环境说明如下: 系统:Ubuntu 14.0 用户名:uusama 需
    2020-01-11
  • Java JVM程序指令码实例解析

    Java JVM程序指令码实例解析

    这篇文章主要介绍了Java JVM程序指令码实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 java程序转化
    2020-01-11
  • Pytorch 保存模型生成图片方式

    Pytorch 保存模型生成图片方式

    三通道数组转成彩色图片 img=np.array(img1) img=img.reshape(3,img1.shape[2],img1.shape[3]) img=(img+0.5)*255##img做过归一化处理,【-0.5,0.
    2020-01-11
  • Ubuntu16.04安装python3.6.5步骤详解

    Ubuntu16.04安装python3.6.5步骤详解

    下载python3.6.5安装包 1. 上传安装包。打开终端,利用命令cd 进入文件所在文件夹里 python@ubuntu:~/workspace$pwd /home/python/workspace 2. 解压文件 t
    2020-01-11
  • 解决Pytorch 加载训练好的模型 遇到的error问题

    解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- pytorch 已经非常
    2020-01-11
  • Linux下如何永久修改主机名的方法步骤

    Linux下如何永久修改主机名的方法步骤

    想修改自己的主机名,那你可以根据下面的步骤实现 使用hostname 使用hostname命令只能临时改变我们的主机名,当我们重启之后主机名还会恢复成原来的 # hostname n
    2020-01-11
  • 详解mysql8.018在linux上安装与配置过程

    详解mysql8.018在linux上安装与配置过程

    windows下安装介绍:去看看–》mysql8.018在windows下安装介绍 Linux平台: 以下操作以mysql 8.0.18,系统为Ubuntu 16.04.6 LTS (GNU/Linux 4.4.0-142-generic x86_
    2020-01-11
  • webpack proxy 使用(代理的使用)

    webpack proxy 使用(代理的使用)

    为什么要写篇文章 这两天的开发中遇到一些需要代理才能解决的问题, 在这里记录一下, 方便以后的查阅. 为什么要用代理 跨域 在开发过程中, 我们的开发环境一般
    2020-01-11
  • 使用pytorch完成kaggle猫狗图像识别方式

    使用pytorch完成kaggle猫狗图像识别方式

    kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、深度学习爱好者学习之
    2020-01-11

最新评论