diff --git a/README.md b/README.md index 7e06429..caaee84 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ - 中文的GPT2训练代码,使用BERT的Tokenizer或Sentencepiece的BPE model(感谢[kangzhonghua](https://github.com/kangzhonghua)的贡献,实现BPE模式需要略微修改train.py的代码)。可以写诗,新闻,小说,或是训练通用语言模型。支持字为单位或是分词模式或是BPE模式(需要略微修改train.py的代码)。支持大语料训练。 - 微信交流群:请见Issue第一条。 +## NEWS 11.9 + +- [GPT2-ML](https://github.com/imcaspar/gpt2-ml)(与本项目无任何直接关联)已发布,包含1.5B模型。大家如有兴趣或需要可利用scripts文件夹内convert_from_tf_to_pytorch.py的脚本转换为本项目支持的Pytorch格式进行进一步训练或生成测试。 + ## UPDATE 10.25 - 本项目第一个预训练模型已公布,为散文生成模型,具体可查看README模型分享部分。 diff --git a/scripts/convert_fron_tf_to_pytorch.py b/scripts/convert_fron_tf_to_pytorch.py new file mode 100644 index 0000000..1b48a49 --- /dev/null +++ b/scripts/convert_fron_tf_to_pytorch.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI GPT checkpoint.""" + +from __future__ import absolute_import, division, print_function + +import argparse +from io import open + +import torch + +from transformers import (CONFIG_NAME, WEIGHTS_NAME, + GPT2Config, + GPT2Model, + load_tf_weights_in_gpt2) + +import logging +logging.basicConfig(level=logging.INFO) + + +def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): + # Construct model + if gpt2_config_file == "": + config = GPT2Config() + else: + config = GPT2Config.from_json_file(gpt2_config_file) + model = GPT2Model(config) + + # Load weights from numpy + load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME + print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) + torch.save(model.state_dict(), pytorch_weights_dump_path) + print("Save configuration file to {}".format(pytorch_config_dump_path)) + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ## Required parameters + parser.add_argument("--gpt2_checkpoint_path", + default = None, + type = str, + required = True, + help = "Path to the TensorFlow checkpoint path.") + parser.add_argument("--pytorch_dump_folder_path", + default = None, + type = str, + required = True, + help = "Path to the output PyTorch model.") + parser.add_argument("--gpt2_config_file", + default = "", + type = str, + help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" + "This specifies the model architecture.") + args = parser.parse_args() + convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, + args.gpt2_config_file, + args.pytorch_dump_folder_path) \ No newline at end of file