本节主要介绍Fairseq模型以及如何自定义模型。Fairseq中的模型一般是在fairseq/models中,包括常用的Transformer、BART等等,本文以BART为例讲解Fairseq模型的使用以及自定义方法。
BART模型
Fairseq中,BART包括两个文件,分别是model.py(模型内部结构)以及hub_interface.py(读入预训练参数)。
model.py
- Fairseq需要注册模型及模型框架,这样在训练时可以通过
--arch bart_large
识别到该模型及其初始化参数设置
from fairseq.models import register_model, register_model_architecture
@register_model("bart")
@register_model_architecture("bart", "bart_large")
自定义模型时同样需要注册模型以及模型框架,如果是在新的文件夹中写模型的话,需要添加__init__.py
文件,使得程序能够识别到自定义模型
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("xxx.models." + model_name)
-
BARTModel
继承TransformerModel
,在init
时运用BERT的参数随机初始化
from fairseq.models.transformer import TransformerModel
class BARTModel(TransformerModel):
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
self.apply(init_bert_params)
-
upgrade_state_dict_named
读入已经训练好的预训练参数时,比如bart.large,需要对数据进行一个处理。因为原本在训练BART时是通过预测mask内容训练embedding的,我们在finetune时是不需要mask标识的,所以这里要去除掉最后添加的mask标识
if (
loaded_dict_size == len(self.encoder.dictionary) + 1
and "<mask>" not in self.encoder.dictionary
):
truncate_emb("encoder.embed_tokens.weight")
truncate_emb("decoder.embed_tokens.weight")
truncate_emb("encoder.output_projection.weight")
truncate_emb("decoder.output_projection.weight")
这里同步去掉output_projection.weight最后一行,是因为weight tying,令pre-softmax的权重等于embedding层的权重。
自定义模型时若添加了其他标识,在读入预训练模型时embedding层要做好对齐。
hub_interface.py
-
encode
在前后端添加<s>
以及</s>
标识,将文本通过dictionary转变为模型可识别的数字 -
decode
去掉<s>
,并将数字通过dictionary变为文字 -
generate
以beam_size
进行beam search,测试时生成最终文本
下一篇(四)Fairseq任务,主要介绍fairseq自定义任务。