Pytorch计算图
1. 计算图概念
计算图是一种特殊的有向无环图(DAG),用于记录算子(运算)与变量(tensor)之间的关系。一般用矩形表示算子。
算子(运算)包括加减乘除、开方、幂指对、三角函数等可求导运算。
变量(tensor)包括叶子节点和非叶子节点。
2. 如何在计算图中反向传播
使用backward()函数反向传播计算tensor的梯度,并不是计算所有tensor的梯度,只计算满足如下条件的tensor:
1、叶子节点
2、requires_grad=True
3、依赖该tensor的所有tensor的requires_grad=True
举例:
1)
1 | import torch |
1 | x.requires_grad = True |
2)
1 | import torch |
1 | x.requires_grad = True |
3)
1 | import torch |
1 | x.requires_grad = True |
注:detach()后并不会影响到tensor的值。
4)
1 | import torch |
会有如下报错
1 | RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. |
原因:一个计算图只能计算一次反向传播。当反向传播后,这个计算图的内存就会被释放。
更改为如下即可
1 | import torch |
参考
- Pytorch: detach 和 retain_graph,和 GAN的原理解析
- requires_grad,requires_grad_(),grad_fn区别
- 叶子节点和tensor的requires_grad参数
- Pytorch-detach()用法