Tensorflow -> Pytorch Bert预训练模型转换

前因

在寻找pytorch版本的英文版Bert预训练模型时,发现只有中文版的预训练模型,而且因为Tensorflow和Pytorch在读取预训练模型时,读取文件的格式不同,所以不能直接拿来使用。在读取这篇 Pytorch | BERT模型实现,提供转换脚本【横扫NLP】 文章后,发现有可以将预训练模型转换至Pytorch可以读取的文件形式的方法。

介绍

一个名为 Hugging Face 🤗 的团队公开了BERT模型的谷歌官方TensorFlow库的 op-for-op PyTorch 重新实现,其中有脚本可以将Tensorflow预训练模型转换为Pytorch可以读取的形式

这个实现可以为BERT加载任何预训练的TensorFlow checkpoint(特别是谷歌的官方预训练模型),并提供一个转换脚本。

使用说明

下载 Transformers,使用convert_bert_original_tf_checkpoint_to_pytorch.py脚本,你可以在PyTorch保存文件中转换BERT的任何TensorFlow检查点(尤其是谷歌发布的官方预训练模型)。

image.png

这个脚本将TensorFlow checkpoint(以bert_model.ckpt开头的三个文件)和相关的配置文件(bert_config.json)作为输入,并为此配置创建PyTorch模型,从PyTorch模型的TensorFlow checkpoint加载权重并保存生成的模型在一个标准PyTorch保存文件中,可以使用 torch.load() 导入(请参阅extract_features.py,run_classifier.py和run_squad.py中的示例)。

只需要运行一次这个转换脚本,在原文件夹下就可以得到一个PyTorch模型。然后,你可以忽略TensorFlow checkpoint(以bert_model.ckpt开头的三个文件),但是一定要保留配置文件(bert_config.json)和词汇表文件(vocab.txt),因为PyTorch模型也需要这些文件。

要运行这个特定的转换脚本,你需要安装TensorFlow和PyTorch。该库的其余部分只需要PyTorch。

使用方法

linux 下执行或者使用 git 执行 sh 指令
下面是一个预训练的BERT-Base Uncased 模型的转换过程示例:

# export BERT_BASE_DIR = '绝对路径'
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# windows下 
export BERT_BASE_DIR = F:/program/uncased_L-12_H-768_A-12

python convert_bert_original_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path  $BERT_BASE_DIR/bert_model.ckpt   \
  --bert_config_file  $BERT_BASE_DIR/bert_config.json  \
  --pytorch_dump_path  $BERT_BASE_DIR/pytorch_model.bin \

Google的预训练模型下载地址:https://github.com/google-research/bert#pre-trained-models

参考文章

Pytorch | BERT模型实现,提供转换脚本【横扫NLP】

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。