好了接着我们解析ATTN,这里的ATTN是注意力机制。
加了注意力机制,我们可能会得到一个更好的效果
这里我们把ATTN的代码单独搬过来
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))
def forward(self, hidden, encoder_outputs):
# hidden [1, 64, 512], encoder_outputs [14, 64, 512]
max_len = encoder_outputs.size(0)
batch_size = encoder_outputs.size(1)
# Create variable to store attention energies
attn_energies = torch.zeros(batch_size, max_len) # B x S
attn_energies = attn_energies.to(device)
# For each batch of encoder outputs
for b in range(batch_size):
# Calculate energy for each encoder output
for i in range(max_len):
attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
# Normalize energies to weights in range 0 to 1, resize to 1 x B x S
return F.softmax(attn_energies, dim=1).unsqueeze(1)
def score(self, hidden, encoder_output):
# hidden [1, 512], encoder_output [1, 512]
if self.method == 'dot':
energy = hidden.squeeze(0).dot(encoder_output.squeeze(0))
return energy
elif self.method == 'general':
energy = self.attn(encoder_output)
energy = hidden.squeeze(0).dot(energy.squeeze(0))
return energy
elif self.method == 'concat':
energy = self.attn(torch.cat((hidden, encoder_output), 1))
energy = self.v.squeeze(0).dot(energy.squeeze(0))
return energy
同样的第一步初始化,hidden_size和method
method,判断ATTN使用哪个方法的
如果method == 'general' 则,
self.attn = nn.Linear(self.hidden_size, hidden_size)
如果method == 'concat'则,
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))
Linear全连接
parameter类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型
第二步forward方法
传入参数 hidden,encoder_outputs
好了又到猜谜的时间了,encoder_outputs是什么东西呢?
哈哈,真简单,自然是encoder的输出啊,人家的名字就这么写的。
哦豁!原来如此!
首先呢,提取了encoder_outputs的第size0 作为最大长度max_len
接着,提取encoder_outputs的第size1 获取batch_size
另外作者注释 # hidden [1, 64, 512], encoder_outputs [14, 64, 512]
这个是该作者他的seq的长度,但是我们的数据和他不一样,具体还得看情况。
attn_energies = torch.zeros(batch_size, max_len)
创建了一个tensor.
attn_energies = attn_energies.to(device)
指定训练设备
接着循环,双层for循环
for b in range(batch_size):
# Calculate energy for each encoder output
for i in range(max_len):
attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
这个score方法在下面。
我们先看score,
score(self, hidden, encoder_output)
参数是 hidden, encoder_output
维度分别是 hidden [1, 512], encoder_output [1, 512]
if self.method == 'dot'
energy = hidden.squeeze(0).dot(encoder_output.squeeze(0))
什么意思,假设 self.method == 'dot‘
则 两个一维tensor点乘。
elif self.method == 'general':
energy = self.attn(encoder_output)
energy = hidden.squeeze(0).dot(energy.squeeze(0))
return energy
什么意思,假设 method == 'general'
则对 encoder_output 全连接处理,见之前init函数初始化attn
之后将全连接的 energy与处理为一维的hidded点乘
elif self.method == 'concat':
energy = self.attn(torch.cat((hidden, encoder_output), 1))
energy = self.v.squeeze(0).dot(energy.squeeze(0))
return energy
什么意思,假设method == 'concat':
self.attn(torch.cat((hidden, encoder_output), 1)
里面有个cat,这个我们知道猫嘛,拿一只猫吃掉两只耗子,哎嘿!
就是这样,的,才怪,好吧,不开玩笑。
torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接,联系在一起
于是呼,我们的hidden和encoder_output就合成了一个tensor,长度自然是他们加起来的长。多长,作者的是512+512=1024
比对该method=='concat'的时候attn就能知道为啥这时候的attn的全连接的第一位置参数为什么是hidden_size*2了,其实就是为了匹配这个
之后就是将该attn中创建的v点乘energy,作为新的energy
我们再回来看这个循环
for b in range(batch_size):
for i in range(max_len):
attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
# Normalize energies to weights in range 0 to 1, resize to 1 x B x S
这里的第一层是批次大小,即多少个句子,第二层是最大长度,句子的最大长度。
这里将每个attn_energies中的每个值都进行了score处理。
hidden[:, b]代表该批次的所有hidden
encoder_outputs[i, b]代表什么?你说呢!