梯度下降法是用来计算函数最小值的。它的思路很简单,想象在山顶放了一个球,一松手它就会顺着山坡最陡峭的地方滚落到谷底:

1

凸函数图像看上去就像上面的山谷,如果运用梯度下降法的话,就可以通过一步步的滚动最终来到谷底,也就是找到了函数的最小值。

动机

先解释下为什么要有梯度下降法?其实最简单的二维凸函数是抛物线f(x)=x2f(x)=x^2,很容易通过解方程f(x)=0f'(x)=0求出最小值在x=0x=0处:

1

只是有一些凸函数,比如下面这个二元函数(该函数实际上是逻辑回归的经验误差函数,在监督式学习中确实需要求它的最小值):

\begin{align}\begin{aligned} f(w_0,w_1) &=\frac{1}{6}\Big[\ln\Big(1+e^{w_0+2w_1}\Big)+\ln\Big(1+e^{-w_0-7w_1}\Big)\\ &\qquad\quad+\ln\Big(1+e^{-w_0-4w_1}\Big)+\ln\Big(1+e^{w_0+w_1}\Big)\\ &\qquad\quad+\ln\Big(1+e^{-w_0-5w_1}\Big)+\ln\Big(1+e^{w_0+4.5w_1}\Big)\Big] \end{aligned}\end{align}

要求它的最小值点就需要解如下方程组:

\begin{align}\begin{cases} \begin{aligned} \frac{\partial f}{\partial w_0} &=\frac{1}{6}\Big[\frac{e^{w_0+2w_1}}{1+e^{w_0+2w_1}}-\frac{e^{-w_0-7w_1}}{1+e^{-w_0-7w_1}}\\ &\qquad\quad-\frac{e^{-w_0-4w_1}}{1+e^{-w_0-4w_1}}+\frac{e^{w_0+w_1}}{1+e^{w_0+w_1}}\\ &\qquad\quad-\frac{e^{-w_0-5w_1}}{1+e^{-w_0-5w_1}}+\frac{e^{w_0+4.5w_1}}{1+e^{w_0+4.5w_1}}\Big]=0 \end{aligned}\\ \begin{aligned} \frac{\partial f}{\partial w_1} &=\frac{1}{6}\Big[\frac{2e^{w_0+2w_1}}{1+e^{w_0+2w_1}}-\frac{7e^{-w_0-7w_1}}{1+e^{-w_0-7w_1}}\\ &\qquad\quad-\frac{4e^{-w_0-4w_1}}{1+e^{-w_0-4w_1}}+\frac{e^{w_0+w_1}}{1+e^{w_0+w_1}}\\ &\qquad\quad-\frac{5e^{-w_0-5w_1}}{1+e^{-w_0-5w_1}}+\frac{4.5e^{w_0+4.5w_1}}{1+e^{w_0+4.5w_1}}\Big]=0 \end{aligned} \end{cases} \end{align}

这个方程组实在太复杂了,直接求解难度太高,好在f(w0,w1)f(w_0,w_1)的图像就像一座山谷:

1

所以可以用梯度下降法来找到f(w0,w1)f(w_0,w_1)的谷底,也就是最小值。

最简单的例子

梯度下降法在本文不打算进行严格地证明和讲解,主要通过一些例子来讲解,先从最简单的凸函数f(x)=x2f(x)=x^2开始讲起。

梯度向量

假设起点在x0=10x_0=10处,也就是将球放在x0=10x_0=10

1

它的梯度为 1 维向量:

\begin{align}\nabla f(x_0)=f'(x_0)\boldsymbol{i}=\Big(f'(x_0)\Big)=\left(2x|_{x_0=10}\right)=(20)\end{align}

这是在xx轴上的向量,它指向函数值增长最快的方向,而f(x0)-\nabla f(x_0)就指向减少最快的方向:

1

x0x_0也看作 1 维向量(x0)(x_0),通过和f(x0)-\nabla f(x_0)相加,可以将之向f(x0)-\nabla f(x_0)移动一段距离得到新的向量(x1)(x_1)

\begin{align}(x_1)=(x_0)-\eta \nabla f(x_0)\end{align}

其中η\eta称为步长,通过它可以控制移的动距离,本节设η=0.2\eta=0.2,那么:

\begin{align}(x_1)=(x_0)-\eta \nabla f(x_0)=(10)-0.2\times (20)=(6)\end{align}

此时小球(也就是起点)下降到了x1=6x_1=6这个位置:

1

迭代

x1x_1的梯度为:

\begin{align}\nabla f(x_1)=f'(x_1)\boldsymbol{i}=\Big(f'(x_1)\Big)=\left(2x|_{x_1=6}\right)=(12)\end{align}

继续沿着梯度的反方向走:

\begin{align}(x_2)=(x_1)-\eta \nabla f(x_1)=(6)-0.2\times(12) = (3.6)\end{align}

小球就滚到了更低的位置:

1

重复上述过程到第 10 次,小球基本上就到了最低点,即有x100x_{10}\approx 0

1

梯度下降法

把每一次的梯度向量f\nabla f的模长列f||\nabla f||出来,可以看到是在不断减小的,因此这种方法称为梯度下降法:

\begin{align} \begin{array}{c|c|c} \hline \quad\quad&x_0&x_1&x_2&x_3&x_4&x_5&x_6&x_7&x_8&x_9&x_{10}\\ \hline\\ ||\nabla f||&20&12&7.2&4.32&2.59&1.56&0.93&0.56&0.34&0.2&0.12\\ \\ \hline \end{array} \end{align}

这也比较好理解,当最终趋向于 0 时有:

\begin{align}||\nabla f||=0\implies\nabla f=0\implies f'(x)=0\end{align}

所以梯度下降法求出来的就是最小值(或者在附近)。

步长

上面谈到了可以通过步长η\eta来控制每次移动的距离,下面来看看不同步长对最终结果的影响。

过小

如果设η=0.01\eta=0.01就过于小了,迭代 20 次后离谷底还很远,实际上 100 次后都无法到达谷底:

1

合适

上面例子中用的η=0.2\eta=0.2是较为合适的步长,10 次就差不多找到了最小值:

1

较大

如果令η=1\eta=1,这个时候会来回震荡(下图看上去只有两个点,实际上在这两个点之间来来回回):

1

过大

继续加大步长,比如令η=1.1\eta=1.1,反而会越过谷底,不断上升:

1

总结

总结下,不同的步长η\eta,随着迭代次数的增加,会导致被优化函数f(x)f(x)的值有不同的变化:

1

寻找合适的步长η\eta是个手艺活,在工程中可以将上图画出来,根据图像来手动调整:

  • f(x)f(x)往上走(红线),自然是η\eta过大,需要调低
  • f(x)f(x)一开始下降特别急,然后就几乎没有变化(棕线),可能是η\eta较大,需要调低
  • f(x)f(x)几乎是线性变化(蓝线),可能是η\eta过小,需要调高

三维的例子

原理都介绍完了,下面再通过一个三维的例子来加强对梯度下降法的理解。假设函数为:

\begin{align}f(\boldsymbol{x})=x_1^2+2x_2^2\end{align}

其图像及等高线如下(等高线中心的蓝点表示最小值):

1

下面用梯度下降法来寻找最小值。

前进一步

设初始点为x0=(3.5,3.5)\boldsymbol{x}_0=(-3.5,-3.5),此时梯度为:

\begin{align}\nabla f(\boldsymbol{x}_0)=(\frac{\partial f(\boldsymbol{x}_0)}{\partial x_1},\frac{\partial f(\boldsymbol{x}_0)}{\partial x_2})=(2x_1, 4x_2)\Big |_{x_1=-3.5,x_2=-3.5}=(-7, -14)\end{align}

令步长η=0.1\eta=0.1,那么下一个点为:

\begin{align} \begin{aligned} \boldsymbol{x}_1 &=\boldsymbol{x}_0-\eta\nabla f(\boldsymbol{x}_0)\\ &=(-3.5,-3.5)-0.1\times(-7,-14)=(-2.8,-2.1) \end{aligned} \end{align}

可以看到向最小值方向前进了一步:

1

迭代

同样的方法找到下一个点:

\begin{align}\begin{aligned} \boldsymbol{x}_2 &=\boldsymbol{x}_1-\eta\nabla f(\boldsymbol{x}_1)\\ &=(-2.8,-2.1)-0.1\times(-5.6,-8.4)=(-2.24,-1.26) \end{aligned}\end{align}

此时又向最小值靠近了:

1

如此迭代20次后,差不多找到了最小值:

1