如果面试官问【聊一下RNN中的梯度消失】
盲猜很多同学的回答可以简化成这样形式【由于网络太深,梯度反向传播会出现连乘效应,从而出现梯度消失】
这样的回答,如果用在普通网络,类似MLP,是没有什么问题的,但是放在RNN中,是错误的。
RNN的梯度是一个和,是近距离梯度和远距离梯度的和;
RNN中的梯度消失的含义是远距离的梯度消失,而近距离梯度不会消失,从而导致总的梯度被近的梯度主导,同时总的梯度不会消失。
这也是为什么RNN模型能以学到远距离依赖关系。
简单的解释一下原因。
首先,我们要明白一点,RNN是共享一套参数的(输入参数,输出参数,隐层参数),这一点非常的重要。
当然,我们在理解RNN的时候,会把RNN按照时间序列展开多个模块,可能会认为是多套参数,这个是不对的哈。
如下所示:
然后,假设我们现在的时间序列为3,有如下公式存在:
现在假设我们只是使用t=3时刻的输出去训练模型,同时使用MSE作为损失函数,那么我们在t=3时刻,损失函数就是:
求偏导的时候,就是这样的情况:
其实看到这里,答案已经出来了。
我们以第二个公式为例,也就是对$w_{x}$ 求偏导,如果时间序列程度为t,我们简化一下成下面这个公式:
时间序列越长,出现连乘的部分越集中出现在靠后面的公式上,比如$a_{t}$,但是前面的公式是不受影响的,比如$a_{1}$,也就是梯度是肯定存在的。
总结一下:RNN中的梯度消失和普通网络梯度消失含义不同,它的真实含义是远距离的梯度消失,而近距离梯度不会消失,同时总的梯度不会消失,从而导致总的梯度被近的梯度主导。