SGD过程中的噪声如何帮助避免局部极小值和鞍点?

来自UCBerkeleyRISELab的本科研究员NoahGolmant发表博客,从理论的角度分析了损失函数的结构,并据此解释随机梯度下降(SGD)中的噪声如何帮助避免局部极小值和鞍点,为设计和改良深度学习架构提供了很有用的参考视角 。
当我们着手训练一个很酷的机器学习模型时,最常用的方法是随机梯度下降法(SGD) 。随机梯度下降在高度非凸的损失表面上远远超越了朴素梯度下降法 。这种简单的爬山法技术已经主导了现代的非凸优化 。然而,假的局部最小值和鞍点的存在使得分析工作更加复杂 。理解当去除经典的凸性假设时,我们关于随机梯度下降(SGD)动态的直觉会怎样变化是十分关键的 。向非凸环境的转变催生了对于像动态系统理论、随机微分方程等框架的使用,这为在优化解空间中考虑长期动态和短期随机性提供了模型 。
在这里,我将讨论在梯度下降的世界中首先出现的一个麻烦:噪声 。随机梯度下降和朴素梯度下降之间唯一的区别是:前者使用了梯度的噪声近似 。这个噪声结构最终成为了在背后驱动针对非凸问题的随机梯度下降算法进行「探索」的动力 。
mini-batch噪声的协方差结构
介绍一下我们的问题设定背景 。假设我想要最小化一个包含N个样本的有限数据集上的损失函数f:R^n→R 。对于参数x∈R^n,我们称第i个样本上的损失为f_i(x) 。现在,N很可能是个很大的数,因此,我们将通过一个小批量估计(mini-batchestimate)g_B:来估计数据集的梯度g_N: 。其中,B?{1,2,…,N}是一个大小为m的mini-batch 。尽管g_N本身就是一个关于梯度?f(x)的带噪声估计,结果表明,mini-batch抽样可以生成带有有趣的协方差结构的估计 。
引理1(Chaudhari&Soatto定理:https://arxiv.org/abs/1710.11029):在回置抽样(有放回的抽样)中,大小为m的mini-batch的方差等于Var(g_B)=1/mD(x),其中
该结果意味着什么呢?在许多优化问题中,我们根本的目标是最大化一些参数配置的似然 。因此,我们的损失是一个负对数似然 。对于分类问题来说,这就是一个交叉熵 。在这个例子中,第一项是对于(负)对数似然的梯度的协方差的估计 。这就是观测到的Fisher信息 。当N趋近于正无穷时,它就趋向于一个Fisher信息矩阵,即相对熵(KL散度)的Hessian矩阵 。但是KL散度是一个与我们想要最小化的交叉熵损失(负对数似然)相差甚远的常数因子 。
因此,mini-batch噪声的协方差与我们损失的Hessian矩阵渐进相关 。事实上,当x接近一个局部最小值时,协方差就趋向于Hessian的缩放版本 。
绕道Fisher信息
在我们继续详细的随机梯度下降分析之前,让我们花点时间考虑Fisher信息矩阵I(x)和Hessian矩阵?^2f(x)之间的关系 。I(x)是对数似然梯度的方差 。方差与损失表面的曲率有什么关系呢?假设我们处在一个严格函数f的局部最小值,换句话说,I(x?)=?^2f(x?)是正定的 。I(x)引入了一个x?附近的被称为「Fisher-Raometric」的度量指标:d(x,y)=√[(x?y)^TI(x?)(x?y)] 。有趣的是,参数的Fisher-Rao范数提供了泛化误差的上界(https://arxiv.org/abs/1711.01530) 。这意味着我们可以对平坦极小值的泛化能力更有信心 。
回到这个故事中来
接下来我们介绍一些关于随机梯度下降动态的有趣猜想 。让我们做一个类似中心极限定理的假设,并且假设我们可以将估计出的g_B分解成「真实」的数据集梯度和噪声项:g_B=g_N+(1√B)n(x),其中n(x)~N(0,D(x)) 。此外,为了简单起见,假设我们已经接近了极小值,因此D(x)≈?^2f(x) 。n(x)在指数参数中有一个二次形式的密度ρ(z):
这表明,Hessian矩阵的特征值在决定被随机梯度下降认为是「稳定」的最小值时起重要的作用 。当损失处在一个非常「尖锐」(二阶导很大)的最小值,并且此处有许多绝对值大的、正的特征值时,我很可能会加入一些把损失从朴素梯度下降的吸引域中「推出来」的噪声 。类似地,对于平坦极小值,损失更有可能「稳定下来」 。我们可以用下面的技巧做到这一点:


特别声明:本站内容均来自网友提供或互联网,仅供参考,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。