diff --git a/README.md b/README.md index 6d48b48..faf9cf2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ## NEWS 11.9 -- [GPT2-ML](https://github.com/imcaspar/gpt2-ml)(与本项目无任何直接关联)已发布,包含1.5B中文GPT2模型。大家如有兴趣或需要可利用scripts文件夹内convert_from_tf_to_pytorch.py的脚本转换为本项目支持的Pytorch格式进行进一步训练或生成测试。 +- [GPT2-ML](https://github.com/imcaspar/gpt2-ml)(与本项目无任何直接关联)已发布,包含1.5B中文GPT2模型。大家如有兴趣或需要可将其转换为本项目支持的Pytorch格式进行进一步训练或生成测试。 ## UPDATE 10.25 diff --git a/scripts/convert_fron_tf_to_pytorch.py b/scripts/convert_fron_tf_to_pytorch.py deleted file mode 100644 index 1b48a49..0000000 --- a/scripts/convert_fron_tf_to_pytorch.py +++ /dev/null @@ -1,75 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OpenAI GPT checkpoint.""" - -from __future__ import absolute_import, division, print_function - -import argparse -from io import open - -import torch - -from transformers import (CONFIG_NAME, WEIGHTS_NAME, - GPT2Config, - GPT2Model, - load_tf_weights_in_gpt2) - -import logging -logging.basicConfig(level=logging.INFO) - - -def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): - # Construct model - if gpt2_config_file == "": - config = GPT2Config() - else: - config = GPT2Config.from_json_file(gpt2_config_file) - model = GPT2Model(config) - - # Load weights from numpy - load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME - print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) - torch.save(model.state_dict(), pytorch_weights_dump_path) - print("Save configuration file to {}".format(pytorch_config_dump_path)) - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - ## Required parameters - parser.add_argument("--gpt2_checkpoint_path", - default = None, - type = str, - required = True, - help = "Path to the TensorFlow checkpoint path.") - parser.add_argument("--pytorch_dump_folder_path", - default = None, - type = str, - required = True, - help = "Path to the output PyTorch model.") - parser.add_argument("--gpt2_config_file", - default = "", - type = str, - help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" - "This specifies the model architecture.") - args = parser.parse_args() - convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, - args.gpt2_config_file, - args.pytorch_dump_folder_path) \ No newline at end of file