天天看点

LSTM网络的反向传播数学公式的严格矩阵推导证明

本文主要是针对LSTM网络的反向传播公式进行推导,在观看前,请确保自己拥有矩阵求导,向量求导,矩阵求导布局,链式求导法则以及LSTM网络的相关知识。因本人水平有限,如有错误请大家指出。当然,如果不想看推导过程的话,可以直接使用推导结果就行,本人已经用matlab试验成功,对我自己的数据集有95%以上的识别正确率。本科的时候用CNN识别这个数据集只有91%左右,可以说效果还行了。

圆圈中带一个点的符号是矩阵或者向量点乘。

LSTM网络的反向传播数学公式的严格矩阵推导证明

这是LSTM前向传播的算法回顾。反向传播算法首先要定义c,h的反向传播误差量,上标的t代表的是第t个时间步。h是隐藏层,c是LSTM存贮长期信息的路径。

LSTM网络的反向传播数学公式的严格矩阵推导证明

我们首先要推导每一个时间步的h和c的反向传播误差值。因为只有知道每一层的这两个值,才能更新梯度,从而运用梯度下降法。上图先计算输出层的h和c的反向传播值,再计算每一时间步的h的反向传播误差值。

LSTM网络的反向传播数学公式的严格矩阵推导证明
LSTM网络的反向传播数学公式的严格矩阵推导证明

我们用每一个时间步的h的反向传播值计算c的每一个时间步的反向传播值。

LSTM网络的反向传播数学公式的严格矩阵推导证明

接下来就是更新各个时间步的权重值和偏差值。首先是遗忘门。

LSTM网络的反向传播数学公式的严格矩阵推导证明

然后是输入门

LSTM网络的反向传播数学公式的严格矩阵推导证明
LSTM网络的反向传播数学公式的严格矩阵推导证明

最后是输出门

LSTM网络的反向传播数学公式的严格矩阵推导证明

这就是LSTM网络反向传播的全部数学推导,比较复杂,但是是一个基础性的公式推导。只要掌握了这一套流程,任何RNN,DNN或者CNN类型的神经网络甚至加上注意力机制的网络以及变体都可以自己推导。

继续阅读