论文解读 | Transformer 原理深入浅出

Attention 机制由 Bengio 团队于 2014 年提出,并广泛应用在深度学习的各个领域。而 Google 提出的用于生成词向量的 Bert 在 NLP 的 11 项任务中取得了效果的大幅提升,Bert 正是基于双向 Transformer。

Transformer 是第一个完全依赖于 Self-Attention 来计算其输入和输出表示的模型,而不使用序列对齐的 RNN 或 CNN。更准确的讲,Transformer 由且仅由 self-Attention 和 Feed Forward Neural Network 组成。一个基于 Transformer 的可训练的神经网络可以通过堆叠 Transformer 的形式进行搭建,作者的实验是通过搭建编码器和解码器各 6 层,总共 12 层的 Encoder-Decoder,并在机器翻译中取得了 BLEU 值得新高。

Transformer 结构

image

解释一下上面这个结构图。Transformer 采用的也是经典的 Encoder 和 Decoder 架构,由 Encoder 和 Decoder 组成。

Encoder 的结构由 Multi-Head Self-Attention 和 position-wise feed-forward network 组成,Encoder 的输入由 Input Embedding 和 Positional Embedding 求和组成。

Decoder 的结构由 Masked Multi-Head Self-Attention,Multi-Head Self-Attention 和 position-wise feed-forward network 组成。Decoder 的初始输入由 Output Embedding 和 Positional Embedding 求和得到。

上图左半边 Nx 框出来的部分是 Encoder 的一层,Transformer 中 Encoder 有 6 层。

上图右半边 Nx 框出来的部分是 Decoder 的一层,Transformer 中 Decoder 有 6 层。

image

Encoder

image

Encoder 由 6 个相同的层组成,每个层包含 2 个部分:

  • Multi-Head Self-Attention
  • Position-Wise Feed-Forward Network (全连接层)

两个部分都有残差连接 (redidual connection),然后接一个 Layer Normalization。

Encoder 的输入由 Input Embedding 和 Positional Embedding 求和组成。

如果你是刚开始学 Transformer,你可能会问:

  • Multi-Head Self-Attention 是什么?
  • 残差连接 (redidual connection) 是什么?
  • Layer Normalization 是什么?

后面都会一一解答,请往后看。

Decoder

image

和 Encoder 相似,Decoder 也是由 6 个相同的层组成,每个层包含 3 个部分:

  • Multi-Head Self-Attention
  • Multi-Head Context-Attention
  • Position-Wise Feed-Forward Network

上面三个部分都有残差连接 (redidual connection),然后接一个 Layer Normalization。

Decoder 多了个 Multi-Head Context-Attention,如果理解了 Multi-Head Self-Attention,这个就很好理解了,后面会提到这两个 Attention。

Self-Attention 机制

Attention 常用的有两种,一种是加性注意力(Additive Attention),另一组是点乘注意力(Dot-product Attention),论文采用的是点乘注意力,这种注意力机制相比加法注意力机制,更快,同时更省空间。

Self-Attention 是 Transformer 的核心内容,然而作者并没用详细讲解。

以下面这句话为例,作为我们翻译的输入语句,我们可以看下 Attention 如何对这句话进行表示。

The animal didn’t cross the street because it was too tired

我们可以思考一个问题,“it” 指代什么?是 “street” 还是 “animal” ? 对人来说,很容易就能知道是 “animal”,但是对于算法来说,并没有这么简单。

模型处理单词 “it” 时,Attention 允许将 “it” 和 “animal” 联系起来。当模型处理每个位置时,Attention 对不同位置产生不同的注意力,使其来更好的编码当前位置的词,如果你熟悉 RNN,就知道 RNN 如何根据之前的隐状态信息来编码当前词。

即:当编码 “it” 时,部分 Attention 集中于 “the animal”,并将其表示合并到 “it” 的编码中。

image

RNN 要逐步递归才能获取全局信息,因此一般要双向 RNN 才比较好,且下一时刻信息要依赖于前面时刻的信息。CNN 只能获取局部信息,是通过叠层来增大感受野,Attention 思路最为粗暴,一步到位获得了全局信息。

而 Transformer 使用 Self-Attention,简单的解释:通过确定Q和K之间的相似程度来选择V

使用 Self-Attention 有几个好处:

  • 每一层的复杂度小
    • 如果输入序列 n 小于表示维度 d 的话,Self-Attention 的每一层时间复杂度有优势。
    • 当 n 比较大时,作者也给出了解决方案,Self-Attention 中每个词不是和所有词计算 Attention,而是只与限制的 r 个词进行 Attention 计算。
  • 并行 Multi-Head Attention 和 CNN 一样不依赖前一时刻的计算,可以很好的并行,优于 RNN。
  • 长距离依赖 优于 Self-Attention 是每个词和所有词计算 Attention,所以不管他们中间有多长距离,最大路径长度都只是 1,可以捕获长距离依赖关系。

上面讲到 Decoder 中有两种 Attention,一种是 Self-Attention,一种是 Context-Attention。

Context-Attention 也就是 Encoder 和 Decoder 之间的 Attention,也可以称之为 Encoder-Decoder Attention。

无论是Self-Attention 还是 Context-Attention,它们在计算 Attention 分数的时候,可以有很多选择:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那么我们的Transformer模型,采用的是哪种呢?答案是:scaled dot-product attention

image

为什么要加这个缩放因子呢?论文里给出了解释:如果 dk 很小,加性注意力和点乘注意力相差不大,但是如果 dk 很大,点乘得到的值很大,如果不做 scaling,结果就没有加性注意力好,另外,点乘结果过大,使得经过 softmax 之后的梯度很小,不利于反向传播的进行,所以我们通过对点乘的结果进行scaling。

image

先简单说下 Q、K、V 是什么:

  • Encoder 的 Self-Attention 中,Q、K、V 都来自同一个地方(相等),他们是上一层 Encoder 的输出,对于第一层 Encoder,他们就是 Word Embedding 和 Positional Embedding 相加得到的输入。
  • Decoder 的 Self-Attention 中,Q、K、V都来自于同一个地方(相等),它们是上一层 Decoder 的输出,对于第一层 Decoder,他们就是 Word Embedding 和 Positional Embedding 相加得到的输入。但是对于 Decoder,我们不希望它能获得下一个 time step(将来的信息),因此我们需要进行 Sequence masking。
  • 在 Encoder-Decoder Attention 中,Q 来自于上一层 Decoder 的输出,K 和 V 来自于 Encoder 的输出,K 和 V 是一样的。

Multi-Head Attention

论文提出,由于不同的 Attention 的权重侧重点不一样,所以将这个任务交给不同的 Attention 一起做,最后取综合结果会更好,有点像 CNN 中的 Keynel。

文章表示,将 Q、K、V 通过一个线性映射后,分成 h 份,对没分进行 Scaled Dot-Product Attention 效果更好, 再把这几个部分 Concat 起来,过一个线性层的效果更好,可以综合不同位置的不同表征子空间的信息

image

论文里面,d_{model}=512,h=8。所以在scaled dot-product attention里面的

d_q=d_k=d_v=d_{model}/h=512/8=64

Residual connection 残差连接

在了解残差网络之前,先思考下面的问题:

  • 神经网络越深越好吗?

下图中显示,传统神经网络越深效果不一定好。而 Deep Residual Learning for Image Recognition 这篇论文认为,理论上,可以训练一个浅层网络,然后再这个训练好的浅层网络上堆几层恒等映射层,即输出等于输入层,构建一个深层网络。浅层网络和深层网络得到的结果一模一样,因为堆上去的层是恒等变换的。

这样就可以得出一个结论:理论上,在训练集上,深层网络不会比浅层网络差。但是为什么出现下面这种情况呢?随着层数增加,训练集上效果反而变差,这被称为退化问题。原因是随着网络越来越深,训练和优化变得越来越难,过深的网络会产生退化问题,效果反而不如相对较浅的网络。而餐内存网络可以解决这个问题,残差网络月神,训练集上效果越好。

image

残差网络通过加入 shortcut connections,变得更加容易被优化。包含一个 shortcut connection 的几层网络被称为一个残差块(residual block)。残差块分成两部分直接映射部分和残差部分。

残差网络由残差块组成,一个残差块可以表示为:

x_{l+1}=x_1+F(x_1,W_1)

image

残差网络有什么好处呢?显而易见:因为增加了 x 项,那么该网络求 x 的偏导的时候,多了一项常数 1,所以反向传播过程,梯度连乘,也不会造成梯度消失。

残差网络的实现非常简单:

def residual(sublayer_fn,x):
        return sublayer_fn(x)+x

Layer normalization

Normalization 有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为 0 方差为 1 的数据。我们在把数据送入激活函数之前进行 Normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区。

随着训练的进行,网络中的参数也随着梯度下降在不停更新。

  • 一方面,当底层网络中参数发生微弱变化时,由于每一层中的线性变换与非线性激活映射,这些微弱变化随着网络层数的加深而被放大(类似蝴蝶效应)。
  • 另一方面,参数的变化导致每一层的输入分布会发生改变,进而上层的网络需要不停地去适应这些分布变化,使得我们的模型训练变得困难。上述这一现象叫做 Internal Covariate Shift。

BN 的作者给 Internal Covariate Shift 的定义为:在深层网络训练过程中,由于网络中参数变化而引起内部节点数据分布发生变化的这一过程被称作 Internal Covariate Shift。

BN 就是为了解决这一问题,一方面可以简化计算过程,一方面经过规范化处理后让数据尽可能保留原始表达能力。

BN 的主要思想是:在每一层的每一批数据上进行归一化。

image

说完 Batch Normalization,就该说说咱们今天的主角 Layer normalization

那么什么是 Layer Normalization 呢?它也是归一化数据的一种方式,不过 LN 是在每一个样本上计算均值和方差,而不是 BN 那种在批方向计算均值和方差

image

Mask

现在终于轮到讲解 Mask 了! 大概就是对某些值进行掩盖,使其不产生效果。

Transformer 模型里面涉及两种 Mask。分别是 Padding Mask 和 Sequence Mask。

其中,Padding Mask 在所有的 Scaled Dot-Product Attention 里面都需要用到,而 Sequence Mask 只有在 Decoder 的 Self-Attention 里面用到。

所以,我们之前 Scaled Dot-Product Attention 的 forward 方法里面的参数 attn_mask 在不同的地方会有不同的含义。

Padding Mask

什么是 Padding Mask 呢?回想一下,我们的每个批次输入序列长度是不一样的。我们要对输入序列进行对齐!就是给在较短的序列后面填充 0。因为这些填充的位置,其实是没什么意义的,所以我们的 Attention 机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 Softmax,这些位置的概率就会接近 0 !

而我们的 Padding Mask 实际上是一个张量,每个值都是一个 Boolen,值为 False 的地方就是我们要进行处理的地方。

def padding_mask(seq_k, seq_q):
    # seq_k 和 seq_q 的形状都是 [B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    # shape [B, L_q, L_k]
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  
    return pad_mask

Sequence mask

文章前面也提到,Sequence Mask 是为了使得 Decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。

那么具体怎么做呢?也很简单:产生一个上三角矩阵,上三角的值全为 1,下三角的权值为 0,对角线也是 0。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。

image

本来 Mask 只需要二维的矩阵即可,但是考虑到我们的输入序列都是批量的,所以我们要把原本 2 维的矩阵扩张成 3 维的张量。

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

回到本小结开始的问题,attn_mask 参数有几种情况?分别是什么意思?

  • 对于decoder的self-attention,里面使用到的scaled dot-product attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个 mask 相加作为attn_mask。
  • 其他情况,attn_mask 一律等于 padding mask。

至此,Mask 相关的问题解决了。

Positional encoding

因为 Transformer 利用 Attention 的原因,少了对序列的顺序约束,这样就无法组成有意义的语句。为了解决这个问题,Transformer 对位置信息进行编码。

image

pos 指词语在序列中的位置,偶数位置,使用正弦编码,奇数位置,使用余弦编码。

上述公式解释:给定词语的位置 pos,我们可以把它编码成 d_model 维的向量!也就是说,位置编码的每一个维度对应正弦曲线,波长构成了从 2π1000*2π 的等比序列。

上面的位置编码是绝对位置编码。但是词语的相对位置也非常重要。这就是论文为什么要使用三角函数的原因!

正弦函数能够表达相对位置信息,主要数学依据是以下两个公式:

sin(α+β)=sinαcosβ+cosαsinβ

cos(α+β)=cosαcosβ-sinαsinβ

上面的公式说明,对于词汇之间的位置偏移 k,PE(pos+k) 可以表示成 PE(pos)PE(k) 组合的形式,相当于有了可以表达相对位置的能力。

class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)

我们对 Position Encoding 打印查看其位置图像:

plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward(Variable(torch.zeros(1, 100, 20)))
plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])
image

我们还尝试使用学习 Position Embedding,发现这两个版本的结果几乎相同。我们选择正弦,因为它可以处理更长序列的情况。

Position-wise Feed-Forward network

这是一个全连接网络,包含两个线性变换和一个非线性函数 (实际上就是 ReLU)。公式如下

FFN=max(0,xW_1+b_1)W_2+b_2

这个线性变换在不同的位置都表现地一样,并且在不同的层之间使用不同的参数。

这里实现上用到了两个一维卷积。

总结

比起传统的 RNN、CNN 模型,Transformer 的优势在于它在每个时刻 t 求得的隐藏向量 ht 都包含整个序列的信息 ( 其实是 Self-Attention 结构的优势,可建模出任意一个时刻 item 和所有时刻 item 的相关性 ) 。因此可将 Transformer 结构用于用户的短期兴趣 Embedding 建模,然后再将该 Embedding 向量用于召回或者 Ranking 阶段。

如果您觉得文章对您有帮助,欢迎关注个人公众号学习交流。


image
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,776评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,527评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,361评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,430评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,511评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,544评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,561评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,315评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,763评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,070评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,235评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,911评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,554评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,173评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,424评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,106评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,103评论 2 352

推荐阅读更多精彩内容