天天看点

LSTM 入门级解读

记录学习过程,方便日后查用。本贴包括数学计算过程和模型解读。

如有错误请指出,感谢大家的指导。

图片来源 LSTM模型结构的可视化 - 知乎

LSTM 入门级解读

淡绿色的方块被称为cell,是构成LSTM的主要结构。实际上对于RNN类网络来说,都会有一个这样的结构块,在时间上循环这个结构块就构成了RNN网络。上图是最基础的LSTM网络。

LSTM的单元输入总共有3个部分 h是隐藏层,X是数据输入,C可以看成是网络的记忆部分。所有红色的单元是运算符,运算过程就是简单的套用运算符;所有黄色的单元是网络层,运算过程类似感知机,sigma符号代表的激活函数默认为sigmoid函数。

数学运算过程

LSTM的cell内部总共有3个主要的门,第一个被称为忘记门(forget gate) ,用来决定上一轮的输入能有多少影响到这一轮的输入。

忘记门公式

LSTM 入门级解读
LSTM 入门级解读

中括号表示concatenate,单纯的将两个向量进行维度上的合并,如x有100维,h有200维,那么中括号就会返回一个300维的向量。忘记门会对上一轮的输入做一个筛选,和输入门的输入一起做加法得到本轮的记忆。

输入门决定了这一轮的主要输入。

LSTM 入门级解读
LSTM 入门级解读
LSTM 入门级解读

输出门

LSTM 入门级解读
LSTM 入门级解读

至此,我们得到了本轮输出

LSTM 入门级解读

模型简单解读

LSTM能拥有长时记忆的主要原因就在于变量C,C的运算结构中包含了加法。对比传统的RNN网络只有一个tanh来说,更不容易出现梯度爆炸或者梯度消失的情况。

LSTM的参数个数计算。假设词向量的维度是m,隐藏层维度为n。

那么参数总数为((m+n)*n+n)*4。

上文提到LSTM虽然是链式结构,但是是在时间上循环同一个单元,所以cell之间所有的参数是共享的,不共享的是每个cell内4个网络层的参数。4个网络层都是感知机的模式,相当于4个全连接层。全连接层的输入维度是x和h的concatenate,输出维度是h,再加上偏置和输出维度相同,所以参数数量一共是 (m+n)*n+n,见下图。因为总共有4个这样的网络,所以再乘以4。

LSTM 入门级解读

题外话

一般在使用LSTM做文本数据时,我们关注的不仅仅只是过去的信息,可能还会有未来的信息,即“结合上下文”。所以会用上双向的LSTM,双向LSTM可以看成LSTM和他的镜像结合在一起,最终两个LSTM的隐层结合一下再输出。所以参数个数是LSTM的两倍。

LSTM还有很多的变种,但大体的结构都大差不大。

继续阅读