1 Small Yet Excellent
On December 4, 2019, the NLP team at Huawei's Noah's Ark Lab open-sourced TinyBERT (TinyBERT: Distilling BERT for Natural Language Understanding). The model is 7.5x smaller than BERT-base and 9.4x faster at inference. Excellent!
GitHub repository:
https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT
2 Two-Stage Learning
Figure: illustration of the two-stage learning framework
2.1 General Distillation
The original BERT (without fine-tuning) serves as the teacher network and TinyBERT as the student network. Using a large-scale general-domain corpus as training data, Transformer distillation is performed on general-domain text to improve TinyBERT's generalization ability. Because of the smaller word-embedding dimension, the fewer hidden units, and the fewer layers, this general TinyBERT still performs noticeably worse than the teacher BERT.
Since the training data is a large-scale corpus, this stage demands considerable compute, so in practice we can simply download the pre-trained general TinyBERT. The official repository provides English pre-trained models.
2.2 Task-specific Distillation
This stage uses task-specific data to help TinyBERT learn more task-related knowledge: the fine-tuned BERT acts as the teacher network and the general TinyBERT as the student network.
2.2.1 Data Augmentation
Data augmentation is common in image processing (e.g., cropping, flipping, scaling). TinyBERT proposes a data augmentation method for text.
Run data_augmentation.py to augment the dataset; the augmented data is saved automatically to ${GLUE_DIR/TASK_NAME}$/train_aug.tsv:
python data_augmentation.py --pretrained_bert_model ${BERT_BASE_DIR}$ \
--glove_embs ${GLOVE_EMB}$ \
--glue_dir ${GLUE_DIR}$ \
--task_name ${TASK_NAME}$
The data augmentation workflow: for each word in a training sentence, candidate replacements are generated (by BERT's masked-language-model predictions for single-piece words, and by GloVe nearest neighbors for multi-piece words), and the word is swapped for a sampled candidate with a fixed probability; repeating this several times yields the augmented training set (see the sketch below).
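A minimal sketch of just the MLM-based candidate step, written against the Hugging Face transformers API purely for illustration; the helper name and parameters below are assumptions, and the official data_augmentation.py additionally handles GloVe candidates, sub-word pieces, and sampling probabilities.
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
mlm = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()

def mlm_candidates(words, idx, top_k=5):
    """Mask the word at position idx and let the MLM head propose replacements."""
    masked = list(words)
    masked[idx] = tokenizer.mask_token
    inputs = tokenizer(" ".join(masked), return_tensors="pt")
    with torch.no_grad():
        logits = mlm(**inputs).logits
    mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0, 0]
    top_ids = logits[0, mask_pos].topk(top_k).indices.tolist()
    return tokenizer.convert_ids_to_tokens(top_ids)

print(mlm_candidates("the movie was really good".split(), 4))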
2.2.2 Two Sub-distillation Tasks
Task 1: Transformer-layer distillation
Command:
# ${FT_BERT_BASE_DIR}$ contains the fine-tuned BERT-base model.
python task_distill.py --teacher_model ${FT_BERT_BASE_DIR}$ \
--student_model ${GENERAL_TINYBERT_DIR}$ \
--data_dir ${TASK_DIR}$ \
--task_name ${TASK_NAME}$ \
--output_dir ${TMP_TINYBERT_DIR}$ \
--max_seq_length 128 \
--train_batch_size 32 \
--num_train_epochs 10 \
--aug_train \
--do_lower_case
TinyBERT's Transformer distillation maps every k-th teacher layer to a student layer. For example, the teacher BERT has 12 layers; if the student is set to 4 layers, a transformer loss is computed every 3 layers. The mapping function is g(m) = 3m, where m is the student encoder layer index. Concretely, student layer 1 corresponds to teacher layer 3, layer 2 to layer 6, layer 3 to layer 9, and layer 4 to layer 12. Each layer's transformer loss in turn consists of two parts: attention-based distillation and hidden-states-based distillation.
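For reference, the two per-layer objectives are roughly as given in the TinyBERT paper (notation follows the paper: h is the number of attention heads, A_i the unnormalized attention matrix of head i, H the hidden states, and W_h a learnable projection from the student's hidden size to the teacher's, which corresponds to fit_dense in the code below):
$$L_{attn} = \frac{1}{h}\sum_{i=1}^{h} \mathrm{MSE}\left(A_i^{S}, A_i^{T}\right), \qquad L_{hidn} = \mathrm{MSE}\left(H^{S} W_h, H^{T}\right)$$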
Key code analysis:
BERT
Word (Chinese character) embeddings
word embeddings = token embeddings + position embeddings + token type embeddings
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name
        # and be able to load any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
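As a quick, self-contained sanity check of the sum above (toy sizes chosen for illustration; LayerNorm and dropout omitted), the three embeddings broadcast to the same [batch, seq_len, hidden] shape:
import torch
import torch.nn as nn

# Toy sizes for illustration only
vocab_size, hidden_size, max_pos, type_vocab = 100, 16, 32, 2
word_emb = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
pos_emb = nn.Embedding(max_pos, hidden_size)
type_emb = nn.Embedding(type_vocab, hidden_size)

input_ids = torch.randint(1, vocab_size, (2, 8))                   # [batch, seq_len]
position_ids = torch.arange(8).unsqueeze(0).expand_as(input_ids)   # [batch, seq_len]
token_type_ids = torch.zeros_like(input_ids)

embeddings = word_emb(input_ids) + pos_emb(position_ids) + type_emb(token_type_ids)
print(embeddings.shape)  # torch.Size([2, 8, 16])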
BertLayer (one Transformer encoder block):
Figure: the computation flow of one BertLayer
class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        # Multi-head attention; layer_att holds the attention scores
        attention_output, layer_att = self.attention(
            hidden_states, attention_mask)
        # Position-wise feed-forward layer
        intermediate_output = self.intermediate(attention_output)
        # Residual structure: combine the feed-forward output with the attention output
        layer_output = self.output(intermediate_output, attention_output)
        # layer_output is the block's overall output; layer_att is attention_scores
        return layer_output, layer_att
Note: layer_att above is the attention-score matrix computed before the softmax, i.e. scores = QK^T / sqrt(d_k) plus the additive attention mask.
BertEncoder: a stack of num_hidden_layers BertLayers
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config)
                                    for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask):
        # hidden_states is the output of BertEmbeddings
        all_encoder_layers = []
        all_encoder_atts = []
        for _, layer_module in enumerate(self.layer):
            # Record each BertLayer's input (the first entry is the word-embedding output;
            # afterwards each layer's input is the previous layer's output),
            # together with that layer's attention scores
            all_encoder_layers.append(hidden_states)
            hidden_states, layer_att = layer_module(hidden_states, attention_mask)
            all_encoder_atts.append(layer_att)
        # Record the output of the last BertLayer, so all_encoder_layers ends up with
        # num_hidden_layers + 1 entries (the embedding output plus every layer's output)
        all_encoder_layers.append(hidden_states)
        return all_encoder_layers, all_encoder_atts
class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                output_all_encoded_layers=True, output_att=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Word (character) embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids)
        # Run num_hidden_layers BertLayer passes
        encoded_layers, layer_atts = self.encoder(embedding_output,
                                                  extended_attention_mask)
        # Pool the first-token ([CLS]) representation of the last layer
        pooled_output = self.pooler(encoded_layers)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        if not output_att:
            return encoded_layers, pooled_output
        return encoded_layers, layer_atts, pooled_output
Let's first look at the definition of the student model:
class TinyBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels, fit_size=768):
        super(TinyBertForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        # fit_dense projects the student's hidden size up to the teacher's (fit_size)
        self.fit_dense = nn.Linear(config.hidden_size, fit_size)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, is_student=False):
        sequence_output, att_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=True, output_att=True)
        logits = self.classifier(torch.relu(pooled_output))
        tmp = []
        if is_student:
            # Project every student hidden state to the teacher's hidden size
            for s_id, sequence_layer in enumerate(sequence_output):
                tmp.append(self.fit_dense(sequence_layer))
            sequence_output = tmp
        return logits, att_output, sequence_output
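Note the extra fit_dense layer: it is the learnable projection W_h from the paper, lifting the student's hidden states to the teacher's hidden size (fit_size, 768 for BERT-base) so the hidden-state MSE below is well defined. A minimal shape sketch, assuming a 312-dimensional student such as the 4-layer TinyBERT:
import torch
import torch.nn as nn

# Assumed sizes: 4-layer TinyBERT hidden size 312, BERT-base hidden size 768
student_hidden, fit_size = 312, 768
fit_dense = nn.Linear(student_hidden, fit_size)

student_rep = torch.randn(2, 128, student_hidden)  # [batch, seq_len, student_hidden]
teacher_rep = torch.randn(2, 128, fit_size)        # [batch, seq_len, teacher_hidden]

# After projection both tensors have the same shape, so MSE is well defined
rep_loss = nn.MSELoss()(fit_dense(student_rep), teacher_rep)
print(fit_dense(student_rep).shape, rep_loss.item())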
def main():
    # Key excerpt from task_distill.py: data loading, optimizer setup, and the
    # surrounding training loop are omitted here.

    # Define the teacher model (fine-tuned BERT-base)
    teacher_model = TinyBertForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=num_labels)
    # Define the student model (general TinyBERT)
    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)

    # Student forward pass
    student_logits, student_atts, student_reps = student_model(
        input_ids, segment_ids, input_mask, is_student=True)
    # Teacher forward pass
    teacher_logits, teacher_atts, teacher_reps = teacher_model(
        input_ids, segment_ids, input_mask)

    # Number of teacher layers (e.g. 12) and student layers (e.g. 4)
    teacher_layer_num = len(teacher_atts)
    student_layer_num = len(student_atts)
    assert teacher_layer_num % student_layer_num == 0
    # Layer interval (12 / 4 = 3)
    layers_per_block = int(teacher_layer_num / student_layer_num)

    # Take one teacher attention tensor every layers_per_block (3) layers:
    # student 0 --> teacher 2, 1 --> 5, 2 --> 8, 3 --> 11 (the last layer)
    new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                        for i in range(student_layer_num)]
    for student_att, teacher_att in zip(student_atts, new_teacher_atts):
        # Zero out attention scores <= -1e2 (i.e. masked positions)
        student_att = torch.where(student_att <= -1e2,
                                  torch.zeros_like(student_att).to(device), student_att)
        teacher_att = torch.where(teacher_att <= -1e2,
                                  torch.zeros_like(teacher_att).to(device), teacher_att)
        # Attention-based distillation loss
        tmp_loss = loss_mse(student_att, teacher_att)
        # Accumulate the loss
        att_loss += tmp_loss

    # Hidden-states loss: teacher representations are also sampled every
    # layers_per_block (3) layers, but with a different indexing than above:
    # 0 --> 0 (the word-embedding output), 1 --> 3, 2 --> 6, ..., n --> 3n (the last layer)
    new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
    new_student_reps = student_reps
    for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
        tmp_loss = loss_mse(student_rep, teacher_rep)
        rep_loss += tmp_loss

    loss = rep_loss + att_loss
    tr_att_loss += att_loss.item()
    tr_rep_loss += rep_loss.item()
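To make the two index mappings concrete, this small snippet just recomputes them for a 12-layer teacher and a 4-layer student:
# Recompute the index selections used above
teacher_layer_num, student_layer_num = 12, 4
layers_per_block = teacher_layer_num // student_layer_num  # 3

att_idx = [i * layers_per_block + layers_per_block - 1 for i in range(student_layer_num)]
rep_idx = [i * layers_per_block for i in range(student_layer_num + 1)]

print(att_idx)  # [2, 5, 8, 11]    -> teacher attentions of layers 3, 6, 9, 12 (1-based)
print(rep_idx)  # [0, 3, 6, 9, 12] -> embedding output plus teacher layers 3, 6, 9, 12 (1-based)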
Figure: layer mapping between teacher and student
Task 2: prediction-layer distillation
Command:
python task_distill.py --pred_distill \
--teacher_model ${FT_BERT_BASE_DIR}$ \
--student_model ${TMP_TINYBERT_DIR}$ \
--data_dir ${TASK_DIR}$ \
--task_name ${TASK_NAME}$ \
--output_dir ${TINYBERT_DIR}$ \
--aug_train \
--do_lower_case \
--learning_rate 3e-5 \
--num_train_epochs 3 \
--eval_step 100 \
--max_seq_length 128 \
--train_batch_size 32
Key code analysis:
if output_mode == "classification":
    # Classification: soft cross-entropy between the temperature-scaled
    # student and teacher logits
    cls_loss = soft_cross_entropy(student_logits / args.temperature,
                                  teacher_logits / args.temperature)
elif output_mode == "regression":
    # Regression: MSE against the gold labels
    loss_mse = MSELoss()
    cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))

loss = cls_loss
tr_cls_loss += cls_loss.item()
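soft_cross_entropy is not a built-in PyTorch loss; in task_distill.py it is, roughly, the soft-label cross entropy between the student's log-probabilities and the teacher's probabilities. A sketch of what it computes:
import torch.nn.functional as F

def soft_cross_entropy(predicts, targets):
    # Student log-probabilities vs. teacher probabilities
    # (both logits are already divided by the temperature by the caller)
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).mean()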