Backpropagation

5. Backpropagation

5.1 Backpropagation

Every optimization algorithm needs gradients. Deriving a closed-form expression for the gradient by hand quickly becomes unwieldy, and numerically computed gradients are not accurate enough, so we compute gradients with backpropagation.

We represent a complicated expression as a computational graph, and consider one of its nodes, say \(z=f(x)\):

Upstream gradient: \(\displaystyle\frac{\partial L}{\partial z}\)

Local gradient: \(\displaystyle \frac{\partial z}{\partial x}\)

Then, by the chain rule, the downstream gradient is \[ \frac{\partial L}{\partial x} = \frac{\partial z}{\partial x}\frac{\partial L}{\partial z} \] Applying this rule node by node, from the output back to the inputs, yields the gradient of every parameter.
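
As a minimal sketch (the toy graph and the numbers below are made up purely for illustration), one backward step at a node \(z=f(x)\) looks like this in plain Python:

# One backward step at a node z = f(x) in a toy graph: x -> z = x**2 -> L = 3 * z
x = 2.0

# Forward pass
z = x ** 2        # z = 4.0
L = 3.0 * z       # L = 12.0

# Backward pass
dL_dz = 3.0            # upstream gradient dL/dz
dz_dx = 2.0 * x        # local gradient dz/dx of z = x**2
dL_dx = dz_dx * dL_dz  # downstream gradient by the chain rule -> 12.0
print(dL_dx)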

5.2 Vector Derivatives

\[ x \in \mathbb R,y\in \mathbb R, \frac{\partial y}{\partial x}\in \mathbb R \]

\[ x \in \mathbb R^N,y\in \mathbb R, \frac{\partial y}{\partial x}\in \mathbb R^N, \left(\frac{\partial y}{\partial x}\right)_n = \frac{\partial y}{\partial x_n} \]

\[ x \in \mathbb R^N,y\in \mathbb R^M, \frac{\partial y}{\partial x}\in \mathbb R^{N\times M}, \left(\frac{\partial y}{\partial x}\right)_{n,m} = \frac{\partial y_m}{\partial x_n} \]
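
These shapes can be checked with torch.autograd.functional.jacobian; note that PyTorch returns the Jacobian with the output dimensions first, i.e. the transposed \(M\times N\) layout relative to the convention above (the function below is an arbitrary toy example):

# Checking Jacobian shapes with autograd (arbitrary toy function from R^3 to R^2)
import torch

def f(x):
    return torch.stack([x[0] * x[1], x[1] + x[2]])

x = torch.randn(3)
J = torch.autograd.functional.jacobian(f, x)
print(J.shape)   # torch.Size([2, 3]); entry (m, n) is dy_m / dx_n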

5.3 Backpropagation Through Matrix Multiplication

In the computational graph, consider a matrix-multiplication node: \[ y=xw,\quad x \in \mathbb R^{N \times D},\; w \in \mathbb R^{D\times M},\; y \in \mathbb R^{N\times M} \] If we backpropagated through it directly, the full Jacobians would be enormous and sparse, incurring unnecessary memory overhead.

In fact, the following formulas simplify the computation: \[ \frac{\partial L}{\partial x} = \left(\frac{\partial L}{\partial y}\right)w^{\mathsf T},\qquad \frac{\partial L}{\partial w} = x^{\mathsf T}\left(\frac{\partial L}{\partial y}\right) \] Here are some simple applications of these formulas:

Linear classifier

\[ z=Wx+b \]

# Backpropagation of linear classifier
import torch

class Linear(object):
    @staticmethod
    def backward(dout, cache):
        """
        Inputs:
        - dout: Upstream derivative, of shape (N, M)
        - cache: Tuple of:
          - x: Input data, of shape (N, D)
          - w: Weights, of shape (D, M)
          - b: Biases, of shape (M,)
        Returns a tuple of:
        - dx: Gradient with respect to x, of shape (N, D)
        - dw: Gradient with respect to w, of shape (D, M)
        - db: Gradient with respect to b, of shape (M,)
        """
        x, w, b = cache

        dx = dout.mm(w.t())   # dL/dx = (dL/dy) w^T, shape (N, D)
        dw = x.t().mm(dout)   # dL/dw = x^T (dL/dy), shape (D, M)
        db = dout.sum(dim=0)  # dL/db: sum dL/dy over the batch, shape (M,)

        return dx, dw, db

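As a quick sanity check, the manual gradients can be compared against PyTorch's autograd. The sketch below uses the Linear class defined above, arbitrary shapes, and a simple sum loss (so the upstream gradient dL/dy is all ones):

# Gradient check for Linear.backward against autograd
import torch

N, D, M = 4, 5, 3
x = torch.randn(N, D, requires_grad=True)
w = torch.randn(D, M, requires_grad=True)
b = torch.randn(M, requires_grad=True)

y = x.mm(w) + b          # forward pass of the linear layer
y.sum().backward()       # scalar loss L = sum(y), so dL/dy is all ones

dout = torch.ones(N, M)  # upstream gradient dL/dy
dx, dw, db = Linear.backward(dout, (x.detach(), w.detach(), b.detach()))

print(torch.allclose(dx, x.grad))  # True
print(torch.allclose(dw, w.grad))  # True
print(torch.allclose(db, b.grad))  # True
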
ReLU

\[ \mathrm{ReLU}(z) = \max (0,z) \]

# Backpropagation of ReLU function
import torch

class ReLU(object):
    @staticmethod
    def backward(dout, cache):
        """
        Input:
        - dout: Upstream derivatives, of any shape
        - cache: Input x, of same shape as dout
        Returns:
        - dx: Gradient with respect to x
        """
        x = cache
        dx = dout * (x > 0)  # pass the gradient only where the forward input was positive
        return dx

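A quick usage example of the backward pass above (the numbers are made up): wherever the forward input was negative, the upstream gradient is blocked.

# Quick check: the gradient is zeroed wherever the forward input was <= 0
import torch

x = torch.tensor([-1.0, 2.0, -3.0, 4.0])
dout = torch.tensor([10.0, 20.0, 30.0, 40.0])
dx = ReLU.backward(dout, x)
print(dx)  # tensor([ 0., 20.,  0., 40.])
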
Dropout

# Backpropagation of dropout
import torch

class Dropout(object):
    @staticmethod
    def backward(dout, cache):
        """
        Inputs:
        - dout: Upstream derivatives, of any shape
        - cache: (dropout_param, mask) from Dropout.forward.
        Returns:
        - dx: Gradient with respect to x
        """
        dropout_param, mask = cache  # cache stores both the parameters and the mask
        dx = dout * mask             # gradient flows only through the units kept by the mask
        return dx

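The backward pass above assumes an inverted-dropout forward pass that produced the (dropout_param, mask) cache. The sketch below shows one common way to write it; the function name dropout_forward and the dict keys 'p' (keep probability) and 'mode' are assumptions for illustration, not taken from the original code.

# Hypothetical inverted-dropout forward pass producing the cache used by Dropout.backward
import torch

def dropout_forward(x, dropout_param):
    # Assumed convention: 'p' is the probability of keeping a unit, 'mode' is 'train' or 'test'
    p, mode = dropout_param['p'], dropout_param['mode']
    if mode == 'train':
        # Scale by 1/p at training time so no rescaling is needed at test time
        mask = (torch.rand_like(x) < p).float() / p
        out = x * mask
    else:
        mask = None
        out = x
    cache = (dropout_param, mask)
    return out, cache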