the first commit
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
__pycache__/
.DS_Store
cache/.DS_Store
.idea/workspace.xml
.idea/misc.xml
.idea/GPT2-Chinese.iml
data/
.samples.txt
.idea/modules.xml
.idea/vcs.xml
.idea
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Zeyao Du

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
112
README.md
Normal file
@@ -0,0 +1,112 @@
# GPT2-Chinese

## Description

- Chinese version of GPT2 training code, using a BERT tokenizer or BPE tokenizer. It is based on the extremely awesome repository from the HuggingFace team, [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers). It can write poems, news, or novels, or train general language models. Supports char level, word level and BPE level tokenization, and supports large training corpora.
- In BPE mode, either GPT2's built-in BPE tokenizer or a SentencePiece BPE model can be used (thanks to [kangzhonghua](https://github.com/kangzhonghua) for the contribution).
- WeChat discussion group: add the account xiyuewang8912 on WeChat and send the message "NLP"; you will be invited into the group automatically.
## UPDATE 10.15

- The new version of the [Jiuge poetry generator](https://jiuge.thunlp.cn/lvshi.html) is now online. Its regulated-verse and quatrain backends are models trained by [JamesHujy](https://github.com/JamesHujy) on code adapted from this repository.
## Project Status

- The main architecture of the project is now stable. If you find a bug or have suggestions for features or improvements, feel free to open an issue, submit a PR, or contact the author.
- When gradient accumulation is used, the loss computation may contain a bug.
## Usage

- Create a `data` folder in the project root and place the training corpus there as `train.json`. **train.json is a JSON list in which each element is the full text of one article to train on (not a link to a file)**.
- Run train.py with the `--raw` flag; the data will be preprocessed automatically. A sample invocation is shown below.
- Once preprocessing finishes, training starts automatically.
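Below is a minimal sketch of the expected `data/train.json` layout (the two article strings are placeholders), followed by the training invocation taken from scripts/train.sh in this repository:

``` json
["第一篇文章的正文……", "第二篇文章的正文……"]
```

``` bash
python train.py \
  --model_config config/model_config_small.json \
  --tokenized_data_path data/tokenized/ \
  --tokenizer_path cache/vocab_small.txt \
  --raw_data_path data/train.json \
  --epochs 30 \
  --log_step 200 \
  --stride 512 \
  --output_dir model/ \
  --device 0,1,2,3 \
  --num_pieces 100 \
  --raw
```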
### Generating Text

``` bash
python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save_samples --save_samples_path=/mnt/xx
```

- **--fast_pattern** (contributed by [LeeCP8](https://github.com/LeeCP8)): for small values of `length` the speed difference is negligible; in my own test with length=250 it was about 2 seconds faster. The fast pattern is only used when `--fast_pattern` is passed explicitly.
- **--save_samples**: by default, output samples are only printed to the console; pass this flag to also save them to **samples.txt** in the root directory.
- **--save_samples_path**: the directory to save samples in; multi-level directories are created recursively as needed. A file name cannot be passed, as the file is always named **samples.txt**.
## File Structure

- generate.py and train.py are the generation and training scripts, respectively.
- train_single.py is an extension of train.py for a corpus that is one very large single element (for example, training on a single novel such as Doupo Cangqiong).
- eval.py evaluates the perplexity (ppl) of a trained model (see the sketch after this list).
- generate_texts.py is an extension of generate.py that generates several sentences for each starting keyword in a list and writes them to a file.
- train.json is an example of the expected training data format, for reference.
- The cache folder contains several BERT vocabularies. make_vocab.py is a helper script for building a vocabulary from a train.json corpus file. vocab.txt is the original BERT vocabulary, vocab_all.txt additionally contains classical Chinese words, and vocab_small.txt is a small vocabulary.
- The tokenizations folder contains the three selectable tokenizers: the default BERT tokenizer, the word-level BERT tokenizer, and the BPE tokenizer.
- The scripts folder contains example training and generation scripts.
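As referenced in the eval.py bullet above, here is a sketch of a perplexity evaluation run. The checkpoint path `model/final_model` is an assumption (it is the default model path used by generate.py); point `--pretrained_model` at your own trained model:

``` bash
python eval.py \
  --pretrained_model model/final_model \
  --raw_data_path data/eval.json \
  --raw
```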
## Notes

- This project uses BERT's tokenizer to handle Chinese characters.
- If you use the word-level tokenizer, you do not need to segment the text beforehand; the tokenizer segments it for you.
- If you use the word-level tokenizer, it is best to first build a vocabulary for your own corpus with the make_vocab.py script in the cache folder.
- The model must be trained by yourself. If you complete a pre-training run, you are welcome to share it.
- If you have plenty of memory, or your corpus is small, you can modify the corresponding code in the build_files function of train.py and preprocess the corpus without splitting it into pieces.
- If you use the BPE tokenizer, you need to build your own Chinese vocabulary.
## Corpus

- Corpora can be downloaded from [here](https://github.com/brightmart/nlp_chinese_corpus) and [here](http://thuctc.thunlp.org/#获取链接).
- The Doupo Cangqiong corpus can be downloaded from [here](https://github.com/GaoPeng97/transformer-xl-chinese/tree/master/data/doupo).
## FP16 and Gradient Accumulation Support

- I added fp16 and gradient accumulation support to train.py. If you have apex installed and know what fp16 is, you can enable it by setting the variable fp16=True. However, fp16 training currently does not converge, for reasons that are unclear. A sketch of the usual apex wiring is shown below.
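For reference, this is a minimal sketch of the standard apex mixed-precision setup that the fp16=True path corresponds to; it is not a copy of train.py's exact code, and `model`, `optimizer`, and the batch tensors are stand-ins:

``` python
# Sketch only: standard apex amp usage (assumes NVIDIA apex is installed).
from apex import amp

# one-time setup before the training loop:
# model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

def train_step(model, optimizer, batch_inputs, batch_labels):
    loss = model(input_ids=batch_inputs, labels=batch_labels)[0]
    with amp.scale_loss(loss, optimizer) as scaled_loss:  # replaces loss.backward()
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```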
## Contact

- Mail: ned1991@gmail.com
## Citing

```
@misc{GPT2-Chinese,
  author = {Zeyao Du},
  title = {GPT2-Chinese: Tools for training GPT2 model in Chinese language},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Morizeyao/GPT2-Chinese}},
}
```
## Demo

- The new version of the [Jiuge poetry generator](https://jiuge.thunlp.cn/lvshi.html) is now online. Its regulated-verse and quatrain backends are models trained by [JamesHujy](https://github.com/JamesHujy) on code adapted from this repository.
- Contributed by [leemengtaiwan](https://github.com/leemengtaiwan): an [article giving an intuitive introduction to GPT-2 and visualizing the self-attention mechanism](https://leemeng.tw/gpt2-language-model-generate-chinese-jing-yong-novels.html), plus a [Colab notebook and model](https://colab.research.google.com/drive/1MaT8-HUHfZkdCra0OqZEIr0IFCq0MJBx) that let any user generate new samples with one click.
## Generated Samples

- Below is a generated sample on Doupo Cangqiong, produced by a GPT2 with roughly 50M parameters trained with batch size 32 on 16MB of the novel's text. Here [SEP] stands for a line break.

![avatar](sample/doupo.jpeg)

- Below are generated samples of classical poetry, trained and contributed by [JamesHujy](https://github.com/JamesHujy).

![avatar](sample/poem_1.png)
![avatar](sample/poem_2.png)

- Below are generated samples of classical poems with the genre constrained at generation time, trained and contributed by [JamesHujy](https://github.com/JamesHujy).

![avatar](sample/律诗绝句.png)
![avatar](sample/浣溪沙_江城子.png)
![avatar](sample/蝶恋花_满江红.png)

- Below are sample texts of generated screenplays, trained and contributed by [chiangandy](https://github.com/chiangandy). The samples are kept in the original Chinese.
[starttext]爱情游戏剧情讲述了钢琴父女明致怀萌的爱情、个有着努力的热情以及现实为人生的价值观众,获得一系列爱情的故事。80后录股媒体受到网友分享,是2014年主创陈拉昀出品牌总监于蓝氏集团化验师创业团门的哥哥大国度上海淮河畔,集入第一线公司青年度虽然没有放到的事业,但是蓝正是却不到位主人拒绝了解,而在蓝越的帮助理念出现,也因此开启明朗的误会而经营变成爱河。在一次偶然的编剧集电视剧之夏天上一改变了自命运环球顶樑,三人在创车祸中不知被记忆差网识分到创作,并被问流言败,以及行业服务所有的低调教同才力,陈昭和唐诗诗妍展开了一段截然不同的“2014年间段感情”,两人性格互相治癒的商业奋斗故事,尽管是共90后北京华侨大学录的一个宿舍小旅程和唐如、生等优秀青年,的人生活如何与愿违3个国偶像,并且共同创作何以此他们互相有观众的成功和关心吗?[endtext]
[starttext]学习爱情主要讲述了两对方小曼,经过啼笑皆非的考验,终于选择了三个孩子,携手共同创业来四个孩子,在大城市里创业的成功商。两家内事业的加入了北京城市,经过了一次元城市融风雨故、差异后得到异的他们,最终收获了梦想的真正属于自己的爱情。赞助理想、电视剧、剧等主创业时代人物特点在北京举行开机仪式,该剧以当下海南三个新人青年轻人面人海南梅竹马的电视角,讲述了几个在北京、喜剧代人生活中增强非浪漫的年轻人,以独特的双时代年轻人从来到北京城市化中国大城市走出发展以海南方的变迁在语种城市闯关于人生态的同时,以及他们渐渐的生活方式为自己方向上演了那么简单俗,是当代际拍摄的就如何在这个城市里都市里?那么平静的城市就是城市的风格特张嘉和支持工作打造,而这是一点就要打造出机场话剧组会。化身处处棋逢貌各种文化的人都非常独特的煽情,交织了相,滑稽等来自外衣的东北漂亮、内地,者和两位女孩子敢称是哑女孩子。交织里的人齐飞一开泰块玩笑,令人印象太趋的气质,让人眼看这个性格非常喜剧,知道的是一个“东北漂”人的外国小养家,让她耳熟练读剧的外形象显老大。之后齐飞、表示爱朗的齐飞、范儿、楚月子、白天杰。两代人的生活里友情似乎没有结合、精彩表态的开朗和丽丽丽。[endtext]

- Below is a generated sample of Jin Yong style wuxia fiction, contributed by [leemengtaiwan](https://github.com/leemengtaiwan). The model has about 82M parameters; the corpus is 50 MB with batch size 16. An [article giving an intuitive introduction to GPT-2 and visualizing the self-attention mechanism](https://leemeng.tw/gpt2-language-model-generate-chinese-jing-yong-novels.html) is provided, plus a [Colab notebook and model](https://colab.research.google.com/drive/1MaT8-HUHfZkdCra0OqZEIr0IFCq0MJBx) for one-click sample generation.

![avatar](sample/金庸_天龍八部.jpg)
![avatar](sample/金庸_倚天屠龍記.jpg)
![avatar](sample/金庸_鹿鼎記.jpg)
![avatar](sample/金庸_神鵰俠侶.jpg)
36
cache/make_vocab.py
vendored
Normal file
@@ -0,0 +1,36 @@
import argparse
import thulac
import json

from tqdm import tqdm
from keras.preprocessing.text import Tokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--raw_data_path', default='../data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--vocab_file', default='vocab_processed.txt', type=str, required=False, help='path of the generated vocab file')
    parser.add_argument('--vocab_size', default=50000, type=int, required=False, help='vocabulary size')
    args = parser.parse_args()

    lac = thulac.thulac(seg_only=True)
    tokenizer = Tokenizer(num_words=args.vocab_size)
    print('args:\n' + args.__repr__())
    print('This script is extremely slow especially for large corpus. Take a break.')

    with open(args.raw_data_path, 'r') as f:
        lines = json.load(f)
    for i, line in enumerate(tqdm(lines)):
        lines[i] = lac.cut(line, text=True)  # segment each article with THULAC

    tokenizer.fit_on_texts(lines)
    vocab = list(tokenizer.index_word.values())
    pre = ['[SEP]', '[CLS]', '[MASK]', '[PAD]', '[UNK]']
    vocab = pre + vocab
    with open(args.vocab_file, 'w') as f:
        for word in vocab[:args.vocab_size + 5]:
            f.write(word + '\n')


if __name__ == "__main__":
    main()
4
cache/make_vocab.sh
vendored
Normal file
@@ -0,0 +1,4 @@
python make_vocab.py \
  --raw_data_path ../data/train.json \
  --vocab_file vocab_user.txt \
  --vocab_size 50000
21128
cache/vocab.txt
vendored
Normal file
19020
cache/vocab_all.txt
vendored
Normal file
32044
cache/vocab_guwen.txt
vendored
Normal file
49005
cache/vocab_seg.txt
vendored
Normal file
13317
cache/vocab_small.txt
vendored
Normal file
10
config/model_config.json
Normal file
@@ -0,0 +1,10 @@
{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "vocab_size": 21128
}
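This config describes a 12-layer, 12-head GPT-2 with 768-dimensional embeddings; note that vocab_size matches the 21128 lines of cache/vocab.txt above. A short sketch of loading it the way eval.py does:

``` python
import pytorch_transformers

# load the architecture hyperparameters from the JSON config
config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file('config/model_config.json')
print(config.n_layer, config.n_head, config.n_embd, config.vocab_size)  # 12 12 768 21128
```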
10
config/model_config_small.json
Normal file
@@ -0,0 +1,10 @@
{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 10,
  "n_positions": 1024,
  "vocab_size": 13317
}
185
eval.py
Normal file
@@ -0,0 +1,185 @@
import pytorch_transformers
import torch
import os
import json
import random
import numpy as np
import argparse
from datetime import datetime
from tqdm import tqdm
from torch.nn import DataParallel


def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # [SEP] stands for a newline and separates paragraphs
        all_len = len(lines)
    for i in tqdm(range(num_pieces)):
        sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
        if i == num_pieces - 1:
            sublines.extend(lines[all_len // num_pieces * (i + 1):])  # append the leftover examples to the last piece
        sublines = [full_tokenizer.tokenize(line) for line in sublines if
                    len(line) > min_length]  # only keep articles longer than min_length
        sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
        full_line = []
        for subline in sublines:
            full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))  # [MASK] marks the start of an article
            full_line.extend(subline)
            full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))  # [CLS] marks the end of an article
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in full_line:
                f.write(str(id) + ' ')
    print('finish')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/eval.json', type=str, required=False, help='raw corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='whether to tokenize the data first')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='batch size')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='report every this many steps')
    parser.add_argument('--stride', default=768, type=int, required=False, help='window stride when slicing the data')
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='how many pieces to split the corpus into')
    parser.add_argument('--min_length', default=128, type=int, required=False, help='minimum article length to include')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='path of the model to evaluate')
    parser.add_argument('--no_wordpiece', action='store_true', help='do not apply wordpiece tokenization')
    parser.add_argument('--output_dir', default='eval_result/', type=str, required=False, help='where to write the results')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.no_wordpiece:
        from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use

    model_config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = n_ctx
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the dataset from scratch
    batch_size = args.batch_size
    log_step = args.log_step
    stride = args.stride
    num_pieces = args.num_pieces
    min_length = args.min_length
    output_dir = args.output_dir

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=min_length)
        print('files built')

    if not args.pretrained_model:
        print('you need to specify a trained model.')
        exit(1)
    else:
        model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.eval()
    model.to(device)

    num_parameters = 0
    parameters = model.parameters()
    for parameter in parameters:
        num_parameters += parameter.numel()
    print('number of parameters: {}'.format(num_parameters))

    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
    print('starting evaluation')
    overall_step = 0

    total_loss = 0
    total_steps = 0
    # eval
    now = datetime.now()
    print('time: {}'.format(now))
    piece_num = 0
    for i in range(num_pieces):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            line = f.read().strip()
        tokens = line.split()
        tokens = [int(token) for token in tokens]
        start_point = 0
        samples = []
        while start_point < len(tokens) - n_ctx:
            samples.append(tokens[start_point: start_point + n_ctx])
            start_point += stride
        start_point -= stride
        last = tokens[start_point + n_ctx:]
        # fixed: was `[full_tokenizer.convert_tokens_to_ids(['[PAD]']) * (n_ctx - len(last))]`;
        # note that the padded tail window is built but never added to `samples`
        last.extend([full_tokenizer.convert_tokens_to_ids('[PAD]')] * (n_ctx - len(last)))
        random.shuffle(samples)
        for step in range(len(samples) // batch_size):  # drop last

            # prepare data
            batch = samples[step * batch_size: (step + 1) * batch_size]
            batch_labels = []
            batch_inputs = []
            for ids in batch:
                int_ids_for_labels = [int(x) for x in ids]
                int_ids_for_inputs = [int(x) for x in ids]
                batch_labels.append(int_ids_for_labels)
                batch_inputs.append(int_ids_for_inputs)
            batch_labels = torch.tensor(batch_labels).long().to(device)
            batch_inputs = torch.tensor(batch_inputs).long().to(device)

            # forward pass
            outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
            loss, logits = outputs[:2]

            # get loss
            if multi_gpu:
                loss = loss.mean()
            total_loss += loss.item()  # fixed: accumulate a Python float, not a tensor
            total_steps += 1

            if (overall_step + 1) % log_step == 0:
                print('now time: {}:{}. Step {} of piece {}, ppl {}'.format(
                    datetime.now().hour,
                    datetime.now().minute,
                    (step + 1),
                    piece_num,
                    torch.exp(loss)))
            overall_step += 1  # fixed: the global step counter was never advanced
        piece_num += 1

    # fixed: the result was only written when output_dir already existed, and
    # f.write() was handed a float instead of a string
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    with open(args.output_dir + 'result.txt', 'w') as f:
        f.write(str(np.exp(total_loss / total_steps)))


if __name__ == '__main__':
    main()
222
generate.py
Normal file
@@ -0,0 +1,222 @@
import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import trange
from pytorch_transformers import GPT2LMHeadModel


def is_word(word):
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    cp = ord(char)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True

    return False


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repetition_penalty=1.0, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(
                **inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :]
            for id in set(generated[0].tolist()):  # fixed: iterate over generated token ids, not tensor rows
                next_token_logits[id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]


def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
    inputs = torch.LongTensor(context).view(1, -1).to(device)
    if len(context) > 1:
        _, past = model(inputs[:, :-1], None)[:2]
        prev = inputs[:, -1].view(1, -1)
    else:
        past = None
        prev = inputs
    generate = [] + context
    with torch.no_grad():
        for i in trange(length):
            output = model(prev, past=past)
            output, past = output[:2]
            output = output[-1].squeeze(0) / temperature
            filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            generate.append(next_token.item())
            prev = next_token.view(1, 1)
    return generate


# the --fast_pattern command line flag selects which of the two generation modes to use
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
             device='cpu', is_fast_pattern=False):
    if is_fast_pattern:
        return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
                                    device=device)
    else:
        return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature,
                               top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='devices to generate on')
    parser.add_argument('--length', default=-1, type=int, required=False, help='length of generated text')
    parser.add_argument('--batch_size', default=1, type=int, required=False, help='batch size for generation')
    parser.add_argument('--nsamples', default=10, type=int, required=False, help='how many samples to generate')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=8, type=int, required=False, help='sample from the top k tokens')
    parser.add_argument('--topp', default=0, type=float, required=False, help='nucleus (cumulative probability) threshold')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary path')
    parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='model path')
    parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='beginning of the generated article')
    parser.add_argument('--no_wordpiece', action='store_true', help='do not apply wordpiece tokenization')
    parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level')
    parser.add_argument('--fast_pattern', action='store_true', help='use the faster generation mode')
    parser.add_argument('--save_samples', action='store_true', help='save the generated samples')
    parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="where to save samples")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.no_wordpiece:
        from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert
    elif args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use
    length = args.length
    batch_size = args.batch_size
    nsamples = args.nsamples
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if length == -1:
        length = model.config.n_ctx
    if args.save_samples:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8')
    while True:
        raw_text = args.prefix
        context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
        generated = 0
        for _ in range(nsamples // batch_size):
            out = generate(
                n_ctx=n_ctx,
                model=model,
                context=context_tokens,
                length=length,
                is_fast_pattern=args.fast_pattern, tokenizer=tokenizer,
                temperature=temperature, top_k=topk, top_p=topp, repetition_penalty=repetition_penalty,
                device=device
            )
            for i in range(batch_size):
                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)
                for i, item in enumerate(text[:-1]):  # make sure English words are surrounded by spaces
                    if is_word(item) and is_word(text[i + 1]):
                        text[i] = item + ' '
                for i, item in enumerate(text):
                    if item == '[MASK]':
                        text[i] = ''
                    if item == '[CLS]' or item == '[SEP]':
                        text[i] = '\n'
                info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
                print(info)
                text = ''.join(text).replace('##', '').strip()
                print(text)
                if args.save_samples:
                    samples_file.write(info)
                    samples_file.write(text)
                    samples_file.write('\n')
                    samples_file.write('=' * 90)
                    samples_file.write('\n' * 2)
                print("=" * 80)
        if generated == nsamples:
            # close file when finished writing.
            if args.save_samples:
                samples_file.close()
            break


if __name__ == '__main__':
    main()
190
generate_texts.py
Normal file
@@ -0,0 +1,190 @@
import torch
import torch.nn.functional as F
import pytorch_transformers
import os
import argparse
from tqdm import trange
from pytorch_transformers import GPT2LMHeadModel

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # select which GPUs the program may use


def is_word(word):
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    cp = ord(char)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True

    return False


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repetition_penalty=1.0, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(
                **inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :]
            for id in set(generated[0].tolist()):  # fixed: iterate over generated token ids, not tensor rows
                next_token_logits[id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--length', default=-1, type=int, required=False, help='length of generated text')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature; higher is more random')
    parser.add_argument('--topk', default=8, type=int, required=False, help='sample from the top k tokens')
    parser.add_argument('--topp', default=0, type=float, required=False, help='nucleus (cumulative probability) threshold')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config path')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary path')
    parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='model path')
    parser.add_argument('--save_path', default='generated/', type=str, required=False, help='where to store the generated files')
    parser.add_argument('--articles_per_title', default=5, type=int, required=False, help='how many articles to generate per title')
    parser.add_argument('--titles', default='萧炎', type=str, required=False, help='list of titles, as one space-separated string')
    parser.add_argument('--titles_file', default='', type=str, required=False,
                        help='file of titles, one per line; if set, --titles is ignored')
    parser.add_argument('--no_wordpiece', action='store_true', help='do not apply wordpiece tokenization')
    parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.no_wordpiece:
        from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert
    elif args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use
    length = args.length
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    titles = args.titles.split()  # the list of titles to generate from
    if args.titles_file:
        with open(args.titles_file, 'r') as f:
            titles = [line.strip('\n') for line in f.readlines()]
    articles_per_title = args.articles_per_title  # how many articles to generate for each title
    save_path = args.save_path  # where to store the output

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    model_config = pytorch_transformers.GPT2Config.from_json_file(args.model_config)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if not os.path.exists(save_path):
        os.mkdir(save_path)
    if length == -1:
        length = model.config.n_ctx

    for i, title in enumerate(titles):
        for j in range(articles_per_title):
            # fixed: the file name was str(i * j), which collides across titles and articles
            with open(save_path + '{}-{}'.format(i, j), 'w') as f:
                context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(title))
                generated = 0
                out = sample_sequence(
                    n_ctx=n_ctx,
                    model=model, length=length,
                    context=context_tokens, tokenizer=tokenizer,
                    temperature=temperature, top_k=topk, top_p=topp, repetition_penalty=repetition_penalty,
                    device=device
                )
                out = out.tolist()[0]

                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)

                # the loop variable is named k so that the title index i is not clobbered
                for k, item in enumerate(text[:-1]):  # make sure English words are surrounded by spaces
                    if is_word(item) and is_word(text[k + 1]):
                        text[k] = item + ' '

                for k, item in enumerate(text):
                    if item == '[MASK]':
                        text[k] = ''
                    if item == '[CLS]' or item == '[SEP]':
                        text[k] = '\n'

                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                text = ''.join(text).replace('##', '').strip()
                # text = ''.join(text.split('\n')[:-1])
                print(text)
                f.write(text)
                print("=" * 80)


if __name__ == '__main__':
    main()
9
requirements.txt
Normal file
@@ -0,0 +1,9 @@
pytorch-transformers
torch
numpy
tqdm
sklearn
keras
tb-nightly
future
thulac
BIN
sample/doupo.jpeg
Normal file
After Width: | Height: | Size: 5.0 MiB |
BIN
sample/poem_1.png
Normal file
After Width: | Height: | Size: 621 KiB |
BIN
sample/poem_2.png
Normal file
After Width: | Height: | Size: 563 KiB |
BIN
sample/tiyu.jpg
Normal file
After Width: | Height: | Size: 1.0 MiB |
BIN
sample/律诗绝句.png
Normal file
After Width: | Height: | Size: 1.3 MiB |
BIN
sample/浣溪沙_江城子.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
sample/蝶恋花_满江红.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
sample/金庸_倚天屠龍記.jpg
Normal file
After Width: | Height: | Size: 87 KiB |
BIN
sample/金庸_天龍八部.jpg
Normal file
After Width: | Height: | Size: 78 KiB |
BIN
sample/金庸_神鵰俠侶.jpg
Normal file
After Width: | Height: | Size: 76 KiB |
BIN
sample/金庸_鹿鼎記.jpg
Normal file
After Width: | Height: | Size: 87 KiB |
8
scripts/generate.sh
Normal file
@@ -0,0 +1,8 @@
# fixed: this script referenced generate_texts_seo.py, which does not exist in this
# commit; the flags match generate.py's interface
python generate.py \
  --device 0 \
  --length 900 \
  --tokenizer_path cache/vocab_small.txt \
  --model_path model/final_model \
  --prefix "[CLS][MASK]" \
  --topp 1 \
  --temperature 1.0
12
scripts/train.sh
Normal file
@@ -0,0 +1,12 @@
python train.py \
  --model_config config/model_config_small.json \
  --tokenized_data_path data/tokenized/ \
  --tokenizer_path cache/vocab_small.txt \
  --raw_data_path data/train.json \
  --epochs 30 \
  --log_step 200 \
  --stride 512 \
  --output_dir model/ \
  --device 0,1,2,3 \
  --num_pieces 100 \
  --raw
142
tokenizations/bpe_tokenizer.py
Normal file
@@ -0,0 +1,142 @@
"""
from https://github.com/openai/gpt-2/, changed for chinese
"""
import json
import os
import sentencepiece as spm
"""
SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation
systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements
subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo] with the
extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end
system that does not depend on language-specific pre/postprocessing.
https://github.com/google/sentencepiece

pip install sentencepiece

or  git clone https://github.com/google/sentencepiece.git
    python setup.py install
"""


def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.max_len = 0

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # fixed: was a bare `except:`
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        return [self.encoder.get(token, 1) for token in self.tokenize(text)]

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        return text

    def tokenize(self, text):
        bpe_tokens = []
        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
        return bpe_tokens

    def convert_tokens_to_ids(self, tokens):
        return [self.encoder.get(token, 1) for token in tokens]


class Encoder_SP:
    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)

    def encode(self, text):
        """
        text="...."
        """
        return self.sp.EncodeAsIds(text)

    def decode(self, tokens):
        """
        tokens=[x1,x2,...]
        """
        text = [int(token) for token in tokens]
        # print(text)
        return self.sp.DecodeIds(text)

    def tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.sp.PieceToId(token) for token in tokens]


def get_encoder(encoder_file, bpe_file):
    # the following lets the same entry function also accommodate sentencepiece models
    filepath, filename = os.path.split(encoder_file)
    shotname, extension = os.path.splitext(filename)

    if (".model" == extension) and (bpe_file == ""):
        return Encoder_SP(encoder_file)  # fixed: was `encoder_path`, an undefined name
    else:
        with open(encoder_file, 'r', encoding="utf-8") as f:
            encoder = json.load(f)
        with open(bpe_file, 'r', encoding="utf-8") as f:
            bpe_data = f.read()
        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
        return Encoder(
            encoder=encoder,
            bpe_merges=bpe_merges,
        )
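A minimal sketch of producing a SentencePiece BPE model that get_encoder() above can load. The corpus path and model prefix are hypothetical; adjust them to your own files:

``` python
import sentencepiece as spm

# trains chinese_bpe.model and chinese_bpe.vocab from a plain-text corpus
# (the input file path and vocab size here are placeholders)
spm.SentencePieceTrainer.Train(
    '--input=../data/corpus.txt --model_prefix=chinese_bpe '
    '--vocab_size=50000 --model_type=bpe'
)

# get_encoder() dispatches on the ".model" extension (with an empty bpe_file)
# and returns an Encoder_SP wrapping the trained model:
# encoder = get_encoder('chinese_bpe.model', '')
# print(encoder.tokenize('今天天气不错'))
```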
1
tokenizations/encoder.json
Normal file
@@ -0,0 +1 @@
{"c":0, "d":1, "大学":2}
5
tokenizations/thulac_dict/seg
Normal file
@@ -0,0 +1,5 @@
[SEP]
[PAD]
[CLS]
[UNK]
[MASK]
436
tokenizations/tokenization_bert.py
Normal file
@@ -0,0 +1,436 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open

from pytorch_transformers.tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
    'bert-large-cased': 512,
    'bert-base-multilingual-uncased': 512,
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
    'bert-base-german-cased': 512,
    'bert-large-uncased-whole-word-masking': 512,
    'bert-large-cased-whole-word-masking': 512,
    'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
    'bert-large-cased-whole-word-masking-finetuned-squad': 512,
    'bert-base-cased-finetuned-mrpc': 512,
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
        max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
            minimum of this value (if specified) and the underlying BERT model's sequence length.
        never_split: List of tokens which will never be split during tokenization. Only has an effect when
            do_wordpiece_only=False
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
        """Constructs a BertTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input
                Only has an effect when do_basic_tokenize=True
            **do_basic_tokenize**: (`optional`) boolean (default True)
                Whether to do basic tokenization before wordpiece.
            **never_split**: (`optional`) list of string
                List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                            pad_token=pad_token, cls_token=cls_token,
                                            mask_token=mask_token, **kwargs)
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split,
                                                  tokenize_chinese_chars=tokenize_chinese_chars)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) to a single string. """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
        """
        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
                               "you may want to check this behavior.")
                kwargs['do_lower_case'] = False
            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                               "but you may want to check this behavior.")
                kwargs['do_lower_case'] = True

        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
        """ Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of tokens not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """ Basic Tokenization of a piece of text.
            Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of tokens not to split.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp) or char.isdigit():
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
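For reference, a minimal self-contained sketch of the greedy longest-match-first loop that `WordpieceTokenizer.tokenize` implements above; the toy vocabulary is illustrative, not taken from any real vocab file:

``` python
# Sketch of WordPiece greedy longest-match-first matching (toy vocab only).
toy_vocab = {"un", "##aff", "##able"}

def wordpiece(token, vocab, unk="[UNK]"):
    pieces, start = [], 0
    while start < len(token):
        end = len(token)
        cur = None
        while start < end:  # shrink the candidate from the right
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub  # continuation pieces carry the ## prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:  # nothing matched: the whole token becomes [UNK]
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
```

Matching always takes the longest prefix present in the vocabulary, so a single out-of-vocabulary span collapses the entire token to `[UNK]` rather than yielding partial pieces.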
453
tokenizations/tokenization_bert_word_level.py
Normal file
@@ -0,0 +1,453 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
import thulac
from io import open

from pytorch_transformers.tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)

lac = thulac.thulac(user_dict='tokenizations/thulac_dict/seg', seg_only=True)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
    'bert-large-cased': 512,
    'bert-base-multilingual-uncased': 512,
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
    'bert-base-german-cased': 512,
    'bert-large-uncased-whole-word-masking': 512,
    'bert-large-cased-whole-word-masking': 512,
    'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
    'bert-large-cased-whole-word-masking-finetuned-squad': 512,
    'bert-base-cased-finetuned-mrpc': 512,
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
        max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is
            always the minimum of this value (if specified) and the underlying BERT model's sequence length.
        never_split: List of tokens which will never be split during tokenization. Only has an effect when
            do_wordpiece_only=False
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
        """Constructs a BertTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input
                Only has an effect when do_basic_tokenize=True
            **do_basic_tokenize**: (`optional`) boolean (default True)
                Whether to do basic tokenization before wordpiece.
            **never_split**: (`optional`) list of string
                List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                            pad_token=pad_token, cls_token=cls_token,
                                            mask_token=mask_token, **kwargs)
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split,
                                                  tokenize_chinese_chars=tokenize_chinese_chars)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str/unicode) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) to a single string."""
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """Instantiate a BertTokenizer from pre-trained vocabulary files."""
        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
                               "you may want to check this behavior.")
                kwargs['do_lower_case'] = False
            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                               "but you may want to check this behavior.")
                kwargs['do_lower_case'] = True

        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
        """Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of tokens not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """Basic tokenization of a piece of text.
        Split on whitespace only; for sub-word tokenization, see WordpieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of tokens not to split.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    # def _tokenize_chinese_chars(self, text):
    #     """Adds whitespace around any CJK character."""
    #     output = []
    #     for char in text:
    #         cp = ord(char)
    #         if self._is_chinese_char(cp) or char.isdigit():
    #             output.append(" ")
    #             output.append(char)
    #             output.append(" ")
    #         else:
    #             output.append(char)
    #     return "".join(output)
    def _tokenize_chinese_chars(self, text):
        """Isolates digits, then segments the text into words with thulac."""
        output = []
        for char in text:
            if char.isdigit():
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        text = "".join(output)
        text = [item[0].strip() for item in lac.cut(text)]
        text = [item for item in text if item]
        return " ".join(text)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
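The word-level file differs from the char-level tokenizer in exactly one method: `_tokenize_chinese_chars` no longer spaces out every CJK character, but isolates digits and hands the text to the thulac segmenter, so text is split into whole words instead of single characters. A rough sketch of that pipeline with the segmenter stubbed out (the stub and sample text are illustrative; the real code uses `lac.cut` from thulac):

``` python
def space_digits(text):
    # isolate digits so they never stick to a neighbouring word
    out = []
    for ch in text:
        out.append(" " + ch + " " if ch.isdigit() else ch)
    return "".join(out)

def word_level_tokenize(text, segmenter):
    text = space_digits(text)
    # thulac's cut returns [word, pos] pairs; seg_only=True leaves pos empty
    words = [item[0].strip() for item in segmenter(text)]
    return " ".join(w for w in words if w)

def fake_cut(s):
    # stand-in for thulac's lac.cut: splits on whitespace, returns [word, pos] pairs
    return [[w, ''] for w in s.split()]

print(word_level_tokenize("共有3人到场", fake_cut))
# stub output: '共有 3 人到场'; a real segmenter would split '人到场' further
```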
2
tokenizations/vocab.bpe
Normal file
@@ -0,0 +1,2 @@
#version: 0.2
大 学
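Each non-header line of vocab.bpe is one BPE merge rule: the two symbols on the line are fused wherever they appear adjacently. A minimal sketch of applying the single rule above (illustrative only):

``` python
def apply_merge(symbols, pair):
    # fuse every adjacent occurrence of `pair` into one symbol
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

print(apply_merge(list("北京大学"), ("大", "学")))  # ['北', '京', '大学']
```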
1
train.json
Normal file
@@ -0,0 +1 @@
["a", "b", "c"]
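train.json is a JSON list of article strings. A hypothetical helper that assembles such a file from a folder of plain-text files (the folder layout and function name are assumptions, not part of the repo):

``` python
import json
import os

def build_train_json(txt_dir, out_path='data/train.json'):
    # one list element per article: the file's full text, not a file path
    articles = []
    for name in sorted(os.listdir(txt_dir)):
        if name.endswith('.txt'):
            with open(os.path.join(txt_dir, name), encoding='utf8') as f:
                articles.append(f.read())
    with open(out_path, 'w', encoding='utf8') as f:
        json.dump(articles, f, ensure_ascii=False)
```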
251
train.py
Normal file
@@ -0,0 +1,251 @@
import pytorch_transformers
import torch
import os
import json
import random
import numpy as np
import argparse
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm
from torch.nn import DataParallel
from tokenizations.bpe_tokenizer import get_encoder


def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # [SEP] marks line breaks; a paragraph ends with [SEP]
    all_len = len(lines)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    for i in tqdm(range(num_pieces)):
        sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
        if i == num_pieces - 1:
            sublines.extend(lines[all_len // num_pieces * (i + 1):])  # append the leftover tail to the last piece
        sublines = [full_tokenizer.tokenize(line) for line in sublines if
                    len(line) > min_length]  # only keep articles longer than min_length
        sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
        full_line = []
        for subline in sublines:
            full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))  # [MASK] marks the start of an article
            full_line.extend(subline)
            full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))  # [CLS] between articles marks the end of one
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in full_line:
                f.write(str(id) + ' ')
    print('finish')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw data first')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='training batch size')
    parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='learning rate')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='number of warmup steps')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='report loss every this many steps')
    parser.add_argument('--stride', default=768, type=int, required=False, help='stride of the sliding window over the training data')
    parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='gradient accumulation steps')
    parser.add_argument('--fp16', action='store_true', help='mixed-precision training')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='number of pieces to split the corpus into')
    parser.add_argument('--min_length', default=128, type=int, required=False, help='minimum article length to include')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output directory')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='pretrained model to start training from')
    parser.add_argument('--writer_dir', default='tensorboard_summary/', type=str, required=False, help='TensorBoard log directory')
    parser.add_argument('--segment', action='store_true', help='tokenize Chinese at word level')
    parser.add_argument('--bpe_token', action='store_true', help='subword')
    parser.add_argument('--encoder_json', default="tokenizations/encoder.json", type=str, help="encoder.json")
    parser.add_argument('--vocab_bpe', default="tokenizations/vocab.bpe", type=str, help="vocab.bpe")

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program uses

    model_config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    if args.bpe_token:
        full_tokenizer = get_encoder(args.encoder_json, args.vocab_bpe)
    else:
        full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = 999999
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the dataset from scratch
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    gradient_accumulation = args.gradient_accumulation
    fp16 = args.fp16  # do not enable on GPUs without half-precision support
    fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    min_length = args.min_length
    output_dir = args.output_dir
    tb_writer = SummaryWriter(log_dir=args.writer_dir)

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=min_length)
        print('files built')

    if not args.pretrained_model:
        model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.train()
    model.to(device)

    num_parameters = 0
    parameters = model.parameters()
    for parameter in parameters:
        num_parameters += parameter.numel()
    print('number of parameters: {}'.format(num_parameters))

    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = pytorch_transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
    scheduler = pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                                          t_total=total_steps)
    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
    print('starting training')
    overall_step = 0
    running_loss = 0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        piece_num = 0
        for i in x:
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            start_point = 0
            samples = []
            while start_point < len(tokens) - n_ctx:
                samples.append(tokens[start_point: start_point + n_ctx])
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens)-n_ctx:])
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):  # drop last

                # prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                batch_inputs = []
                for ids in batch:
                    int_ids = [int(x) for x in ids]
                    batch_inputs.append(int_ids)
                batch_inputs = torch.tensor(batch_inputs).long().to(device)

                # forward pass
                outputs = model.forward(input_ids=batch_inputs, labels=batch_inputs)
                loss, logits = outputs[:2]

                # get loss
                if multi_gpu:
                    loss = loss.mean()
                if gradient_accumulation > 1:
                    loss = loss / gradient_accumulation

                # loss backward
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # optimizer step
                if (step + 1) % gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    overall_step += 1
                    if (overall_step + 1) % log_step == 0:
                        tb_writer.add_scalar('loss', loss.item(), overall_step)
                if (overall_step + 1) % log_step == 0:
                    print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        step + 1,
                        piece_num,
                        epoch + 1,
                        running_loss * gradient_accumulation / log_step))
                    running_loss = 0
            piece_num += 1

        print('saving model for epoch {}'.format(epoch + 1))
        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
        # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
        # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
        print('epoch {} finished'.format(epoch + 1))

        then = datetime.now()
        print('time: {}'.format(then))
        print('time for one epoch: {}'.format(then - now))

    print('training finished')
    if not os.path.exists(output_dir + 'final_model'):
        os.mkdir(output_dir + 'final_model')
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'final_model')
    # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
    # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')


if __name__ == '__main__':
    main()
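Putting the flags above together: preprocessing slices each tokenized piece into windows of n_ctx tokens advanced by stride, and training consumes those windows in shuffled batches. A typical invocation that preprocesses the raw corpus and then trains, using the script's own default paths:

``` bash
python train.py --raw \
                --raw_data_path data/train.json \
                --tokenized_data_path data/tokenized/ \
                --tokenizer_path cache/vocab_small.txt \
                --model_config config/model_config_small.json \
                --epochs 5 --batch_size 8 --device 0,1,2,3
```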
227
train_single.py
Normal file
@@ -0,0 +1,227 @@
import pytorch_transformers
import torch
import os
import json
import random
import argparse
import numpy as np
from datetime import datetime
from torch.nn import DataParallel
from tqdm import tqdm

'''
Use this script when the training material is one continuous dump with no article boundaries.
'''


def build_files(raw_data_path, tokenized_data_path, full_tokenizer, num_pieces):
    with open(raw_data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # [SEP] marks line breaks; a paragraph ends with [SEP]
    single = ''.join(lines)
    len_single = len(single)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    for i in tqdm(range(num_pieces)):
        single_ids = full_tokenizer.convert_tokens_to_ids(
            full_tokenizer.tokenize(single[len_single // num_pieces * i: len_single // num_pieces * (i + 1)]))
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in single_ids[:-1]:
                f.write(str(id) + ' ')
            f.write(str(single_ids[-1]))
            f.write('\n')

    print('finish')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw data first')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='training batch size')
    parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='learning rate')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='number of warmup steps')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='report loss every this many steps')
    parser.add_argument('--stride', default=768, type=int, required=False, help='stride of the sliding window over the training data')
    parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='gradient accumulation steps')
    parser.add_argument('--fp16', action='store_true', help='mixed-precision training')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='number of pieces to split the corpus into')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output directory')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='pretrained model to start training from')
    parser.add_argument('--segment', action='store_true', help='tokenize Chinese at word level')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program uses
    model_config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = 999999
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the dataset from scratch
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    gradient_accumulation = args.gradient_accumulation
    fp16 = args.fp16  # do not enable on GPUs without half-precision support
    fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    output_dir = args.output_dir

    if raw:
        print('building files')
        build_files(raw_data_path=raw_data_path, tokenized_data_path=tokenized_data_path, full_tokenizer=full_tokenizer,
                    num_pieces=num_pieces)
        print('files built')

    if not args.pretrained_model:
        model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.train()
    model.to(device)
    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = pytorch_transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
    scheduler = pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                                          t_total=total_steps)
    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
    print('starting training')
    running_loss = 0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        piece_num = 0
        for i in x:
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            start_point = 0
            samples = []
            while start_point < len(tokens) - n_ctx:
                samples.append(tokens[start_point: start_point + n_ctx])
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens)-n_ctx:])
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):

                # prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                batch_labels = []
                batch_inputs = []
                for ids in batch:
                    int_ids_for_labels = [int(x) for x in ids]
                    int_ids_for_inputs = [int(x) for x in ids]
                    batch_labels.append(int_ids_for_labels)
                    batch_inputs.append(int_ids_for_inputs)
                batch_labels = torch.tensor(batch_labels).long().to(device)
                batch_inputs = torch.tensor(batch_inputs).long().to(device)

                # forward pass
                outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
                loss, logits = outputs[:2]

                # get loss
                if multi_gpu:
                    loss = loss.mean()
                if gradient_accumulation > 1:
                    loss = loss / gradient_accumulation

                # loss backward
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # optimizer step
                if (step + 1) % gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                if (step + 1) % log_step == 0:
                    print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        (step + 1) // gradient_accumulation,
                        piece_num,
                        epoch + 1,
                        running_loss * gradient_accumulation / log_step))
                    running_loss = 0
            piece_num += 1

        print('saving model for epoch {}'.format(epoch + 1))
        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
        # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
        # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
        print('epoch {} finished'.format(epoch + 1))

        then = datetime.now()
        print('time: {}'.format(then))
        print('time for one epoch: {}'.format(then - now))

    print('training finished')
    if not os.path.exists(output_dir + 'final_model'):
        os.mkdir(output_dir + 'final_model')
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'final_model')
    # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
    # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')


if __name__ == '__main__':
    main()
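train_single.py accepts the same core flags as train.py; a sketch of an invocation for one large continuous text (paths follow the script defaults and are illustrative):

``` bash
python train_single.py --raw \
                       --raw_data_path data/train.json \
                       --tokenizer_path cache/vocab_small.txt \
                       --num_pieces 100 --epochs 5 --batch_size 8
```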