TinyBERT理论和源码解析

1 小而精的优秀

2019年12月4日, 华为诺亚方舟实验室的 NLP 团队开源 了TinyBERT(Distilling BERT for Natural Language UnderStanding)模型。此模型,比BERT-base小7.5倍,推理速度快9.4倍。优秀!

github地址:

https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT

2 两段式学习

img

图 两段式学习图示

2.1 General Distillation 通用蒸馏

原始BERT(不进行微调)作为teacher网络, TinyBERT 作为 student 网络,使用大规模语料库作为训练数据, 通过在通用领域文本上执行 Transformer 蒸馏, 提高TinyBERT的泛化能力 。 由于词向量维度的减小,隐层神经元的减少,以及网络层数的减少,tinybert的表现远不如teacher bert。

由于训练数据为大规模预料库,训练时对机器性能要求比较高,所以我们可直接下载预训练好的通用tinybert模型。官网提供了英文的预训练模型。

2.2 Task-specific Distillation 特定任务蒸馏

使用具体任务的数据,帮助TinyBERT学习到更多任务相关的具体知识 。微调的bert作为teacher网络, 微调的tinybert作为student网络。

img

2.2.1数据增强

数据增强在图像处理里比较常见,比如裁剪,反转,缩放等。TinyBert提出了一种文本处理的数据增强方法。

使用data_augmentation.py 运行数据扩充和扩充数据集,结果会自动保存到${GLUE_DIR/TASK_NAME}$/train_aug.tsv

python data_augmentation.py --pretrained_bert_model ${BERT_BASE_DIR}$ \
                            --glove_embs ${GLOVE_EMB}$ \
                            --glue_dir ${GLUE_DIR}$ \  
                            --task_name ${TASK_NAME}$

数据增强的流程:


image.png

2.2.2两个子蒸馏任务

任务1:Transformer层蒸馏
img

执行命令:

# ${FT_BERT_BASE_DIR}$ contains the fine-tuned BERT-base model.

python task_distill.py --teacher_model ${FT_BERT_BASE_DIR}$ \
                       --student_model ${GENERAL_TINYBERT_DIR}$ \
                       --data_dir ${TASK_DIR}$ \
                       --task_name ${TASK_NAME}$ \ 
                       --output_dir ${TMP_TINYBERT_DIR}$ \
                       --max_seq_length 128 \
                       --train_batch_size 32 \
                       --num_train_epochs 10 \
                       --aug_train \
                       --do_lower_case  
                         

TinyBERT 的 transformer 蒸馏采用隔 k 层蒸馏的方式。举个例子,teacher BERT 一共有 12 层,若是设置 student BERT 为 4 层,就是每隔 3 层计算一个 transformer loss. 映射函数为 g(m) = 3 * m, m 为 student encoder 层数。具体对应为 student 第 1 层 transformer 对应 teacher 第 3 层,第 2 层对应第 6 层,第 3 层对应第 9 层,第 4 层对应第 12 层。每一层的 transformer loss 又分为两部分组成,attention based distillation 和 hidden states based distillation.

关键代码分析:

BERT

词(中文字)编码

word embedding = token embeding + position embeddings + token type embeddings

class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

BertLayer(Transformer的一次encoder过程):

image.png

图 一层BertLayer的过程

class BertLayer(nn.Module):
   def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        
    def forward(self, hidden_states, attention_mask):
         # Multi-Head Attention多头注意力层,其中layer_att为Multi-Head Attention的结果, 
         attention_output, layer_att = self.attention(
            hidden_states, attention_mask)
         # Feed Forward 前向运算层
         intermediate_output = self.intermediate(attention_output)
         # resnet结构 (Multi-Head Attention 的输出 + Feed Forward的输出 )
         layer_output = self.output(intermediate_output, attention_output)
         # layer_output为总的输出结果,layer_att 为 attention_scores
         return layer_output, layer_att

注:其中layer_attr为下列公式的计算结果
QK^T/ \sqrt d_k

num_hidden_layers个BertLayer

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config)
                                    for _ in range(config.num_hidden_layers)])
    # hidden_states 为 embedding的结果
    def forward(self, hidden_states, attention_mask):
        all_encoder_layers = []
        all_encoder_atts = []
        for _, layer_module in enumerate(self.layer):
            # 记录每一次的 BertLayer 的输入hidden_states(第一次为word embedding的结果,下一次的输入为上一层的输出), 和此次的layer_att
            all_encoder_layers.append(hidden_states)
            hidden_states, layer_att = layer_module(hidden_states, attention_mask)
            all_encoder_atts.append(layer_att)
            
        # 记录最后一层BertLayer的输出
        all_encoder_layers.append(hidden_states)
        return all_encoder_layers, all_encoder_atts
        
class BertModel(BertPreTrainedModel): 
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)
        
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                output_all_encoded_layers=True, output_att=True):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # word embedding 词(字)编码
        embedding_output = self.embeddings(input_ids, token_type_ids)
        # hidden layer次 Bert Layer
        encoded_layers, layer_atts = self.encoder(embedding_output,
                                                  extended_attention_mask)
        # 取第一个的输出
        pooled_output = self.pooler(encoded_layers)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        if not output_att:
            return encoded_layers, pooled_output

        return encoded_layers, layer_atts, pooled_output
        

我们先看下student model的定义

class TinyBertForSequenceClassification(BertPreTrainedModel):    
    def __init__(self, config, num_labels, fit_size=768):    
        super(TinyBertForSequenceClassification, self).__init__(config)        
        self.num_labels = num_labels        
        self.bert = BertModel(config)        
        self.dropout = nn.Dropout(config.hidden_dropout_prob)       
        self.classifier = nn.Linear(config.hidden_size, num_labels)        
        self.fit_dense = nn.Linear(config.hidden_size, fit_size) 
        self.apply(self.init_bert_weights)    
        
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,                labels=None, is_student=False):        
        sequence_output, att_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True, output_att=True)        
        logits = self.classifier(torch.relu(pooled_output))
        tmp = [] 
        if is_student:   
            for s_id, sequence_layer in enumerate(sequence_output)
                tmp.append(self.fit_dense(sequence_layer))            
                sequence_output = tmp        
        return logits, att_output, sequence_output
def main():
    # teacher model 定义
    teacher_model = TinyBertForSequenceClassification.from_pretrained(args.teacher_model, num_labels=num_labels)
    # student model 定义
    student_model = TinyBertForSequenceClassification.from_pretrained(args.student_model, num_labels=num_labels)
    
    # student model 前向运行
    student_logits, student_atts, student_reps = student_model(input_ids, segment_ids, input_mask, is_student=True)
    # teacher model 前向运行
    teacher_logits, teacher_atts, teacher_reps = teacher_model(input_ids, segment_ids, input_mask)
    
    # 得到student model 层数(如12层) 和 teacher model的层数(如4层)
    teacher_layer_num = len(teacher_atts)
    student_layer_num = len(student_atts)
    assert teacher_layer_num % student_layer_num == 0
    # 间隔层数(12/4=3)
    layers_per_block = int(teacher_layer_num / student_layer_num)
    #每间隔layers_per_block 3提取一次teacher的参数,0-->2, 1-->5, 2-->8, 3-->11 (最后一层)
    new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                        for i in range(student_layer_num)]
    for student_att, teacher_att in zip(student_atts, new_teacher_atts):
        #  如果student_att小于-1e2置为0
        student_att = torch.where(student_att <= -1e2,              torch.zeros_like(student_att).to(device), student_att)
        teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device), teacher_att)
        # 计算stt的loss
        tmp_loss = loss_mse(student_att, teacher_att)
        # loss 累加
        att_loss += tmp_loss
    # 计算reps的loss,每间隔layers_per_block 3提取一次teacher的参数提取逻辑和上面不同,0--->0, 1-->3(word embeeding),2-->6, n-->3n(最后一层))
    new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
    new_student_reps = student_reps
    for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
        tmp_loss = loss_mse(student_rep, teacher_rep)
        rep_loss += tmp_loss

        loss = rep_loss + att_loss
        tr_att_loss += att_loss.item()
        tr_rep_loss += rep_loss.item()

image.png

图:映射关系图

任务2:预测层蒸馏

执行命令:

python task_distill.py --pred_distill  \
                       --teacher_model ${FT_BERT_BASE_DIR}$ \
                       --student_model ${TMP_TINYBERT_DIR}$ \
                       --data_dir ${TASK_DIR}$ \
                       --task_name ${TASK_NAME}$ \
                       --output_dir ${TINYBERT_DIR}$ \
                       --aug_train  \  
                       --do_lower_case \
                       --learning_rate 3e-5  \
                       --num_train_epochs  3  \
                       --eval_step 100 \
                       --max_seq_length 128 \
                       --train_batch_size 32 

关键代码分析:

if output_mode == "classification":   
    # 分类问题,交叉熵损失函数
    cls_loss = soft_cross_entropy(student_logits / args.temperature,                                  teacher_logits / args.temperature)
    # 回归问题,计算MSE
    elif output_mode == "regression":    
        loss_mse = MSELoss()    
        cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))
        loss = cls_losstr_cls_loss += cls_loss.item()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,718评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,683评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,207评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,755评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,862评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,050评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,136评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,882评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,330评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,651评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,789评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,477评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,135评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,864评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,099评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,598评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,697评论 2 351

推荐阅读更多精彩内容