深入浅出RNN

RNN是深度学习在自然语言处理领域中的元老级神经网络,它奠定了神经网络技术在NLP领域的发展方向,其名声仅此于CNN,虽然近年来它已经鲜少露面,但江湖地位不减,就连当红明星GRU和LSTM都是它的变种。

RNN(Recurrent Neural Networks),循环神经网络,指的是模型循环处理每个input,每次循环迭代称为time(在公式中简称“t”)。Figure 1中,等号右侧部分就是RNN的展开图:input, i_ti_1 ~ i_4)经过隐藏层循环处理,每个time会生成一个output,o_to_1~o_4),此外还会生成一个hidden state, h_t ,它是隐藏层对input的学习成果,hidden state会和下一个input一起作为参数传入隐藏层(红色箭头)。

Figure 1: RNN网络结构

我在深入浅出全连接层中提过,RNN是由全连接层(Linear layer)组成的,准确地说,RNN中的input layer、hidden layer、output layer,就是Figure 1中的那些箭头,它们都是全连接层。

本文将会以重写RNN的方式来由浅入深剖析RNN,点击【这里】可以查看完整源码。

Framework && Dataset

除了使用Fastai Library外,还会用它的轻量级NLP数据集:http://files.fast.ai/data/examples/human_numbers.tgz。数据集是从0到9999的英文数字:“one, two, three, four, five, ......, nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine”。

我们的任务就是创建一个RNN N-gram语言模型来学习数数,比如说,看到“one, two, three, four, five, ”这10个连续token(","也是一个token),就能预测出第11个token是"six"。关于token、N-gram、语言模型、embedding以及nlp mini-batch等内容,可以回看以前的文章自己动手开发AI影评写作机器人,这里不再赘述。

Single output

class Model1(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(nv, wordvec_len)
    self.input = nn.Linear(wordvec_len, nh) # input layer
    self.hid = nn.Linear(nh, nh) # hidden layer
    self.out = nn.Linear(nh, nv) # output layer
    self.bn = nn.BatchNorm1d(nh)
  
  def forward(self, x):
    h = torch.zeros(x.shape[0], nh).to(device=x.device)
    for i in range(x.shape[1]):
      h = h + F.relu(self.input(self.emb(x[:, i])))
      h = self.bn(F.relu(self.hid(h)))
    return self.out(h)

如Figure 1所示,Model1通过一个for循环来(递归)处理每个input(token),并将hidden state传递到下一个循环,等到N个(x.shape[1])token都学习完之后,用第N个hidden state来生成一个output。这就是RNN根据前面N个token来预测第N+1个token的过程。

之所以RNN可以预测第N+1个token,是因为hidden state包含了前面N个token的信息。h = h + F.relu(self.input(self.emb(x[:, i]))),为RNN预测提供了前文的信息。

learn = Learner(data, Model1(), loss_func=loss_f, metrics=acc_f)
learn.fit_one_cycle(10, 1e-4)

epoch   train_loss  valid_loss  acc_f   time
0   3.546392    3.752568    0.024586    00:01
1   2.911883    3.326832    0.227022    00:01
2   2.289827    2.718843    0.447610    00:01
3   1.943045    2.349009    0.465303    00:01
4   1.766844    2.192134    0.466222    00:01
5   1.678379    2.125904    0.466452    00:01
6   1.633395    2.100061    0.464844    00:01
7   1.610903    2.090777    0.464614    00:01
8   1.601126    2.087984    0.459099    00:01
9   1.598013    2.088171    0.459099    00:01

Learner是Fastai提供用于模型训练的library,“acc_f”这栏显示的是模型的预测准确率:45.9%。

Maintain state

class Model3(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(nv, wordvec_len)
    self.input = nn.Linear(wordvec_len, nh)
    self.hid = nn.Linear(nh, nh)
    self.out = nn.Linear(nh, nv)
    self.bn = nn.BatchNorm1d(nh)
    self.h = torch.zeros(bs, nh).cuda()
  
  def forward(self, x):
    h = self.h
    res = []
    for i in range(x.shape[1]):
      h = h + torch.tanh(self.input(self.emb(x[:, i])))
      h = self.bn(torch.tanh(self.hid(h)))
      res.append(h)
    self.h = h.detach()
    res = torch.stack(res, dim=1)
    return self.out(res)

Model1的hidden state只保留了同一个mini-batch内的token信息,一旦开始学习新的mini-batch,hidden state就会清零:h = torch.zeros(x.shape[0], nh).to(device=x.device)

因此,Model3会保留前一个mini-batch训练得到的hidden state,换句话说,此时的hidden state保留了整个dataset的信息。不仅如此,Model3不只是预测第N+1个token,它会预测所有token的下一个token。

epoch   train_loss  valid_loss  acc_f   time
0   3.546392    3.752568    0.024586    00:01
                 ......
15  0.193776    0.834583    0.759517    00:00
16  0.179285    0.845180    0.757031    00:00
17  0.167008    0.859531    0.751491    00:00
18  0.157094    0.860083    0.754972    00:00
19  0.150660    0.854162    0.754474    00:00

Model3实现了RNN的基本功能,它的预测准确率达到72.2%。除此之外,用tanh替代relu,准确率会提升到75.4%,标准RNN的默认激活函数就是tanh。

Pytorch RNN

class Model4(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(nv, wordvec_len)
    self.input = nn.Linear(wordvec_len, nh)
    self.rnn = nn.RNN(nh, nh, 1, batch_first=True)
    self.out = nn.Linear(nh, nv)
    self.bn = BatchNorm1dFlat(nh)
    self.h = torch.zeros(1, bs, nh).cuda()
  
  def forward(self, x):
    res, h = self.rnn(self.input(self.emb(x)), self.h)
    self.h = h.detach()
    return self.out(self.bn(res))

epoch   train_loss  valid_loss  acc_f   time
0   3.546392    3.752568    0.024586    00:01
                 ......
15  0.283078    1.124835    0.662003    00:00
16  0.261476    1.067836    0.680043    00:00
17  0.244309    1.112445    0.670455    00:00
18  0.233984    1.112467    0.668253    00:00
19  0.226839    1.101279    0.671307    00:00

Model4是Model3的对标模型,用Pytorch官方提供的RNN实现,从训练结果来看,它准确率比Model3要低不少,因此,这里温馨提醒,不要把官方code当作圣旨。实践表明,relu+batchnorm的组合应用到RNN中卓有成效,可惜,它无法作用于nn.RNN。

class Model4(nn.Module):
  def __init__(self):
    ......
    self.rnn = nn.RNN(nh, nh, 2, batch_first=True, dropout=0.1)
    ......

epoch   train_loss  valid_loss  acc_f   time
0   3.546392    3.752568    0.024586    00:01
                 ......
15  0.123147    0.816822    0.789773    00:00
16  0.111867    0.815223    0.793466    00:00
17  0.102646    0.834602    0.788210    00:00
18  0.096860    0.828752    0.791761    00:00
19  0.092941    0.836145    0.790909    00:00

虽然nn.RNN的表现不如自己手撸的好,但并不是鼓励大家自己造轮子,而且本人也很反对像自己造轮子这种几乎不输出价值的工作方式。事实上,当我们使用2层RNN之后,Model4的表现就优于Model3了。

2层RNN的结构如下图所示,2个隐藏层虽然比1个隐藏层效果更好,但计算量也翻倍了,由于全连接层的参数数量比卷积层的要多得多,深层RNN的训练时间要远远长于相同深度的CNN,因此,RNN层数不宜过多。

Figure 2: 2-layer RNN

GRU

RNN虽然通过hidden state来保留各token的信息,但它并不能有效地处理和使用这些信息,它会将所有训练得来的信息都一股脑地塞进hidden state,并用它来影响后续每个token的学习,而且这种影响不管是好是坏都会随着训练的深入,像滚雪球一样有增无减地将越多越多的信息裹胁进来。

显然RNN缺乏对新老信息(hidden state和token)的调控,这样一来,当要学习的token较多时,起始部分的token信息在hidden state中的占比就会越来越小,换句话说,文本起始部分的内容很容易会被忘记。

GPU和LSTM就是专门为了应对长文本训练而设计的,它们为RNN中增加了input gate、output gate、forget gate、update gate等机制,对进出信息进行筛选,丢弃过时的、无用的信息,调控输入信息对模型训练的影响力等。

class Model5(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(nv, wordvec_len)
    self.input = nn.Linear(wordvec_len, nh)
    self.rnn = nn.GRU(nh, nh, 2, batch_first=True, dropout=0.05)
    self.out = nn.Linear(nh, nv)
    self.bn = BatchNorm1dFlat(nh)
    self.h = torch.zeros(2, bs, nh).cuda()
  
  def forward(self, x):
    res, h = self.rnn(self.input(self.emb(x)), self.h)
    self.h = h.detach()
    return self.out(self.bn(res))


epoch   train_loss  valid_loss  acc_f   time
0   3.546392    3.752568    0.024586    00:01
                 ......
5   0.123571    1.116273    0.833310    00:00
6   0.071810    1.277520    0.835156    00:00
7   0.043868    1.258988    0.838068    00:00
8   0.029036    1.265061    0.839134    00:00
9   0.021350    1.290411    0.838778    00:00

Model5用GRU替代了RNN后,在相同情况下,模型准确率又有了提升,达到了83.8%。关于GRU的分析留待下一篇博文,这里不作展开。

END

本文通过重构RNN的方式详解了RNN模型的构成,通过分析它的缺陷,进一步介绍了GRU和LSTM的工作原理。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,544评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,430评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,764评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,193评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,216评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,182评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,063评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,917评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,329评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,543评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,722评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,425评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,019评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,671评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,825评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,729评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,614评论 2 353

推荐阅读更多精彩内容