From 618cdcaa9f3a55313abdfce6bc802d787b97023f Mon Sep 17 00:00:00 2001 From: "chongjiu.jin" Date: Fri, 10 Jan 2020 15:28:47 +0800 Subject: [PATCH] update to transformers 2.3.0 --- pytorch-bert-code/bert/README.md | 8 ++- ..._bert_original_tf_checkpoint_to_pytorch.py | 61 +++++++++++++++++++ pytorch-bert-code/bert/run.sh | 4 +- 3 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 pytorch-bert-code/bert/convert_bert_original_tf_checkpoint_to_pytorch.py diff --git a/pytorch-bert-code/bert/README.md b/pytorch-bert-code/bert/README.md index 098ea2b..d8e156f 100644 --- a/pytorch-bert-code/bert/README.md +++ b/pytorch-bert-code/bert/README.md @@ -1,8 +1,14 @@ update to transformer 2.3.0 +### 如何将bert model 的Tensorflow模型 转换为pytorch模型 -转换工具已经失效 +convert_bert_original_tf_checkpoint_to_pytorch.py +运行脚本run.sh + +后生成对应pytorch_model.bin + +--- chinese bert https://github.com/ymcui/Chinese-BERT-wwm/blob/master/README_EN.md diff --git a/pytorch-bert-code/bert/convert_bert_original_tf_checkpoint_to_pytorch.py b/pytorch-bert-code/bert/convert_bert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000..806ace5 --- /dev/null +++ b/pytorch-bert-code/bert/convert_bert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert BERT checkpoint."""
+
+
+import argparse
+import logging
+
+import torch
+
+from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+
+
+logging.basicConfig(level=logging.INFO)
+
+
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
+    # Build a BertForPreTraining model from the JSON config at bert_config_file
+    config = BertConfig.from_json_file(bert_config_file)
+    print("Building PyTorch model from configuration: {}".format(str(config)))
+    model = BertForPreTraining(config)
+
+    # Load the weights found at tf_checkpoint_path into the model just built
+    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
+
+    # Save only the model's state_dict (not the whole module) to pytorch_dump_path
+    print("Save PyTorch model to {}".format(pytorch_dump_path))
+    torch.save(model.state_dict(), pytorch_dump_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # All three arguments are required; argparse exits with usage if any is missing
+    parser.add_argument(
+        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
+    )
+    parser.add_argument(
+        "--bert_config_file",
+        default=None,
+        type=str,
+        required=True,
+        help="The config json file corresponding to the pre-trained BERT model. \n"
+        "This specifies the model architecture.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
+    args = parser.parse_args()
+    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
diff --git a/pytorch-bert-code/bert/run.sh b/pytorch-bert-code/bert/run.sh
index 5574450..87f9a5e 100644
--- a/pytorch-bert-code/bert/run.sh
+++ b/pytorch-bert-code/bert/run.sh
@@ -1,3 +1 @@
-export BERT_BASE_DIR=./
-
-transformers bert $BERT_BASE_DIR/bert_model.ckpt $BERT_BASE_DIR/bert_config.json $BERT_BASE_DIR/pytorch_model.bin
\ No newline at end of file
+python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin