Pytorch计算图

1. 计算图概念

计算图是一种特殊的有向无环图(DAG),用于记录算子(运算)与变量(tensor)之间的关系。一般用矩形表示算子。

算子(运算)包括加减乘除、开方、幂指对、三角函数等可求导运算。

变量(tensor)包括叶子节点和非叶子节点。

2. 如何在计算图中反向传播

使用backward()函数反向传播计算tensor的梯度,并不是计算所有tensor的梯度,只计算满足如下条件的tensor:

1、叶子节点

2、requires_grad=True

3、依赖该tensor的所有tensor的requires_grad=True

举例:

1)

1
2
3
4
5
6
import torch

x = torch.tensor(2.0,requires_grad=True)
y = x ** 2
z = y * 4
z.backward()

1

1
2
3
4
5
6
7
x.requires_grad = True
x.is_leaf = True
x.grad = tensor(16.)

y.requires_grad = True
y.is_leaf = Fals
y.grad = None # 默认情况下,对于非叶节点tensor,backward()之后梯度会被释放回收内存

2)

1
2
3
4
5
6
7
import torch

x = torch.tensor(2.0,requires_grad=True)
y = x ** 2
z = y * 4
y.retain_grad() # 保留y的梯度
z.backward()
1
2
3
4
5
6
7
x.requires_grad = True
x.is_leaf = True
x.grad = tensor(16.)

y.requires_grad = True
y.is_leaf = Fals
y.grad = tensor(4.)

3)

2

1
2
3
4
5
6
7
import torch

x = torch.tensor(2.0,requires_grad=True)
y = x ** 2
z = y * 4
y_ = y.detach() #返回一个新的tensor,该tensor从当前计算图中分离,但仍指向原变量的存放位置
z.backward()
1
2
3
4
5
6
7
8
9
10
11
x.requires_grad = True
x.is_leaf = True
x.grad = tensor(16.)

y.requires_grad = True
y.is_leaf = False
y.grad = None

y_.requires_grad = False
y_.is_leaf = True
y_.grad = None # 不在当前计算图中

注:detach()后并不会影响到tensor的值。

4)

1

1
2
3
4
5
6
7
import torch

x = torch.tensor(2.0,requires_grad=True)
y = x ** 2
z = y * 4
z.backward()
z.backward()

会有如下报错

1
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

原因:一个计算图只能计算一次反向传播。当反向传播后,这个计算图的内存就会被释放。

更改为如下即可

1
2
3
4
5
6
7
import torch

x = torch.tensor(2.0,requires_grad=True)
y = x ** 2
z = y * 4
z.backward(retain_graph=True)
z.backward()

参考

  1. Pytorch: detach 和 retain_graph,和 GAN的原理解析
  2. requires_grad,requires_grad_(),grad_fn区别
  3. 叶子节点和tensor的requires_grad参数
  4. Pytorch-detach()用法