attention pytorch实现学习

attention pytorch实现学习

关于global attention概述见:

//www.greatytc.com/p/841557506ab5

本文基于《dive into deep learning》-pytorch

attention原理图(图源《dive into deep learning》).png

Additive Attention

addtive attention.png

如果key和query是不同长度的向量,一般方法是,将两者拼接起来,然后过一个线性层。
这也是常用的concat attention方法

公式也可以写成

a(q,k)= w_v^T tanh(W[q,k])

W_qq+W_kk = [W_q,W_k]\begin{bmatrix} q\\k \end{bmatrix}

实现方式

class AdditiveAttention(nn.Module):
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):#转换为num_hiddens维度,词向量长度
        #假设:query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4) 
        #batch seq word_embedding,  key和value seq_len是一样的,query是一个单独的向量,1×20
        super(AdditiveAttention,self).__init__(**kwargs)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=False)
        self.w_v=nn.Linear(num_hiddens,1,bias=False)#
        self.dropout=nn.Dropout(dropout)
    def forward(self,queries,keys,values,valid_lens):
        queries,keys = self.W_q(queries),self.W_k(keys)#映射到相同维度 [2,1,8] [2,10,8]
        #query增加一个维度为了方便和key相加。key增加一个维度后面需要    
        features = queries.unsqueeze(2)+keys.unsqueeze(1) #torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 10, 8])
        print(queries.unsqueeze(2).shape,keys.unsqueeze(1).shape)
        print(features.shape)#torch.Size([2, 1, 10, 8])
        features = torch.tanh(features)
        
        scores = self.w_v(features)#8 *1
        print(scores.shape)#torch.Size([2, 1, 10, 1])
        scores=scores.squeeze(-1)# w_v消掉最后隐藏层维,因此把这一维去掉,这里就得到了
        print(scores.shape)# 2,1,10  把seq中不需要的部分隐藏掉
        
        self.attention_weigths = masked_softmax(scores,valid_lens)#结果取softmax
        print(self.attention_weigths)
        print(self.attention_weigths.shape)#2,1,10
        # attention weights和values加权相加
        return torch.bmm(self.dropout(self.attention_weigths),values)#2,1,10 2*10*4 ->2*1*4,10个value的权重加和

pytorch 知识

1.bmm

计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,h),注意两个tensor的维度必须为3. https://blog.csdn.net/qq_40178291/article/details/100302375

2.None

[:,None]

None表示该维不进行切片,而是将该维整体作为数组元素处理。

所以,[:,None]的效果就是将二维数组按每行分割,最后形成一个三维数组

torch.tensor([2,3])[:, None]
tensor([[2],
        [3]])
print(torch.arange((10), dtype=torch.float32)[None, :] )
print(torch.tensor([2,3])[:, None])
print(torch.arange((10), dtype=torch.float32)[None, :]<torch.tensor([2,3])[:, None])
------------------------------
tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
tensor([[2],
        [3]])
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]]
repeat_interleave repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None) 参数说明: self: 传入的数据为tensor repeats: 复制的份数 dim: 要复制的维度,可设定为0/1/2.....

sequece_mask

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    #X size=2,10
    maxlen = X.size(1)#10
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None] #index比大小,比这个index小的,都保留为true

    X[~mask] = value#则不保留的部分赋值为value。
    return X

mask_softmax

用于去除不需要的padding部分,mask部分的attention score可以忽视。

def masked_softmax(X,valid_lens):  
    if valid_lens is None:
        return nn.functional.softmax(X,dim=-1)
    else:
        shape=X.shape
        if valid_lens.dim()==1:
            valid_lens = torch.repeat_interleave(valid_lens,shape[1])
        else:
            valid_lens=valid_lens.reshape(-1)
        X=sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)#2,1,10转换成2,10mask,value复制一个极小值
        return nn.functional.softmax(X.reshape(shape),dim=-1)#mask后再softmax

Additive Attention

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
attention.eval()
res =attention(queries, keys, values, valid_lens)
#weight
attention.attention_weigths
#torch.Size([2, 1, 10])# 10个值代表十个weight

show_heatmap

show_heatmaps(attention.attention_weigths.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')
# 2×10  
#query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4) 
#batch1 query 和10个key的交互值
#batch2 query 和10个key的交互值
import torch
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import random
def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):
    display.set_matplotlib_formats('svg')
    num_rows,num_cols=matrices.shape[0],matrices.shape[1]
    print(num_rows,num_cols)
    fig,axes=plt.subplots(num_rows,num_cols,figsize=figsize,sharex=True,sharey=True,squeeze=False)#sharex,sharey共享x,y axes,返回各个子图
    for i,(row_axes,row_matrices) in enumerate(zip(axes,matrices)):#数据
        for j,(ax,matrix) in enumerate(zip(row_axes,row_matrices)):
            print(i,j)
            pcm = ax.imshow(matrix.detach().numpy(),cmap=cmap)
            if i==num_rows-1:
                ax.set_xlabel(xlabel)
            if j==0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm,ax=axes,shrink=0.6)

Scaled Dot-Product Attention

如果query和key的维度相同,可以用点乘注意力。


Scaled Dot-Product Attention 《dive into DL》.png
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    # Shape of `queries`: (`batch_size`, no. of queries, `d`) 
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set `transpose_b=True` to swap the last two dimensions of `keys`
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
import math
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 6))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

在学习文本分类中,我们对余弦相似度会比较熟悉,余弦相似度是用夹角来衡量向量之间相似性的一种方法,

公式://www.greatytc.com/p/a894ebba4a1a

余弦相似度.png

某种程度上,和这种注意力的公式有相似之处,也可以认为点积式注意力机制以这种方法衡量相似度。

点积attention有一个scaled的操作,这个操作的原因可以参考:

https://www.zhihu.com/question/339723385 transformer中的attention为什么scaled?

总结:

数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签

也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难

transformer中的attention为什么scaled? - TniL的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/782509914

attention知识.png

transformer中的attention为什么scaled? - 小莲子的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/811341890

即如果不scale的话,容易造成梯度消失,给参数更新造成困难。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,525评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,203评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,862评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,728评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,743评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,590评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,330评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,244评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,693评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,885评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,001评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,723评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,343评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,919评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,042评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,191评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,955评论 2 355

推荐阅读更多精彩内容