pytorch seq2seq聊天机器人 ATTN解析

好了接着我们解析ATTN,这里的ATTN是注意力机制。

加了注意力机制,我们可能会得到一个更好的效果

这里我们把ATTN的代码单独搬过来

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs):
        # hidden [1, 64, 512], encoder_outputs [14, 64, 512]
        max_len = encoder_outputs.size(0)
        batch_size = encoder_outputs.size(1)

        # Create variable to store attention energies
        attn_energies = torch.zeros(batch_size, max_len) # B x S
        attn_energies = attn_energies.to(device)

        # For each batch of encoder outputs
        for b in range(batch_size):
            # Calculate energy for each encoder output
            for i in range(max_len):
                attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

        # Normalize energies to weights in range 0 to 1, resize to 1 x B x S
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

def score(self, hidden, encoder_output):
        # hidden [1, 512], encoder_output [1, 512]
        if self.method == 'dot':
            energy = hidden.squeeze(0).dot(encoder_output.squeeze(0))
            return energy

        elif self.method == 'general':
            energy = self.attn(encoder_output)
            energy = hidden.squeeze(0).dot(energy.squeeze(0))
            return energy

        elif self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            energy = self.v.squeeze(0).dot(energy.squeeze(0))
            return energy

同样的第一步初始化,hidden_size和method

method,判断ATTN使用哪个方法的

如果method == 'general' 则,
self.attn = nn.Linear(self.hidden_size, hidden_size)

如果method == 'concat'则,
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))

Linear全连接
parameter类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型

第二步forward方法
传入参数 hidden,encoder_outputs

好了又到猜谜的时间了,encoder_outputs是什么东西呢?
哈哈,真简单,自然是encoder的输出啊,人家的名字就这么写的。

哦豁!原来如此!

首先呢,提取了encoder_outputs的第size0 作为最大长度max_len

接着,提取encoder_outputs的第size1 获取batch_size

另外作者注释 # hidden [1, 64, 512], encoder_outputs [14, 64, 512]
这个是该作者他的seq的长度,但是我们的数据和他不一样,具体还得看情况。

attn_energies = torch.zeros(batch_size, max_len)
创建了一个tensor.

attn_energies = attn_energies.to(device)
指定训练设备

接着循环,双层for循环
for b in range(batch_size):
# Calculate energy for each encoder output
for i in range(max_len):
attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

这个score方法在下面。

我们先看score,
score(self, hidden, encoder_output)
参数是 hidden, encoder_output
维度分别是 hidden [1, 512], encoder_output [1, 512]

if self.method == 'dot'
energy = hidden.squeeze(0).dot(encoder_output.squeeze(0))
什么意思,假设 self.method == 'dot‘
则 两个一维tensor点乘。

elif self.method == 'general':
energy = self.attn(encoder_output)
energy = hidden.squeeze(0).dot(energy.squeeze(0))
return energy

什么意思,假设 method == 'general'
则对 encoder_output 全连接处理,见之前init函数初始化attn
之后将全连接的 energy与处理为一维的hidded点乘

elif self.method == 'concat':
energy = self.attn(torch.cat((hidden, encoder_output), 1))
energy = self.v.squeeze(0).dot(energy.squeeze(0))
return energy

什么意思,假设method == 'concat':
self.attn(torch.cat((hidden, encoder_output), 1)
里面有个cat,这个我们知道猫嘛,拿一只猫吃掉两只耗子,哎嘿!
就是这样,的,才怪,好吧,不开玩笑。
torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接,联系在一起
于是呼,我们的hidden和encoder_output就合成了一个tensor,长度自然是他们加起来的长。多长,作者的是512+512=1024

比对该method=='concat'的时候attn就能知道为啥这时候的attn的全连接的第一位置参数为什么是hidden_size*2了,其实就是为了匹配这个

之后就是将该attn中创建的v点乘energy,作为新的energy

我们再回来看这个循环

    for b in range(batch_size):
     
        for i in range(max_len):
            attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

    # Normalize energies to weights in range 0 to 1, resize to 1 x B x S

这里的第一层是批次大小,即多少个句子,第二层是最大长度,句子的最大长度。

这里将每个attn_energies中的每个值都进行了score处理。

hidden[:, b]代表该批次的所有hidden
encoder_outputs[i, b]代表什么?你说呢!

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,332评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,508评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,812评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,607评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,728评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,919评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,071评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,802评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,256评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,576评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,712评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,389评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,032评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,026评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,473评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,606评论 2 350

推荐阅读更多精彩内容