import transformers
import torch
import os
import json
import random
import argparse
import numpy as np
from datetime import datetime
from torch.nn import DataParallel
from tqdm import tqdm

'''
Use this file when the training corpus is one continuous blob of text
rather than a collection of separate articles.
'''


def build_files(raw_data_path, tokenized_data_path, full_tokenizer, num_pieces):
    with open(raw_data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        # [SEP] marks line breaks; between paragraphs it marks the end of a paragraph
        lines = [line.replace('\n', ' [SEP] ') for line in lines]
    single = ''.join(lines)
    len_single = len(single)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    for i in tqdm(range(num_pieces)):
        single_ids = full_tokenizer.convert_tokens_to_ids(
            full_tokenizer.tokenize(single[len_single // num_pieces * i: len_single // num_pieces * (i + 1)]))
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in single_ids[:-1]:
                f.write(str(id) + ' ')
            f.write(str(single_ids[-1]))
            f.write('\n')
    print('finish')
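
# A minimal sketch of exercising build_files() on its own (paths and piece count are
# illustrative assumptions; normally main() calls it when --raw is passed):
#
#   from tokenizations import tokenization_bert
#   tok = tokenization_bert.BertTokenizer(vocab_file='cache/vocab_small.txt')
#   build_files(raw_data_path='data/train.json', tokenized_data_path='data/tokenized/',
#               full_tokenizer=tok, num_pieces=100)
#
# This writes data/tokenized/tokenized_train_0.txt ... tokenized_train_99.txt, each
# containing one line of space-separated token ids for one slice of the corpus.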


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model configuration file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False,
                        help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False,
                        help='raw training corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw corpus first')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='training batch size')
    parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='learning rate')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='number of warmup steps')
    parser.add_argument('--log_step', default=1, type=int, required=False,
                        help='report loss every this many steps')
    parser.add_argument('--stride', default=768, type=int, required=False,
                        help='stride of the sliding window used to take training samples')
    parser.add_argument('--gradient_accumulation', default=1, type=int, required=False,
                        help='gradient accumulation steps')
    parser.add_argument('--fp16', action='store_true', help='mixed precision')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=100, type=int, required=False,
                        help='number of pieces to split the training corpus into')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output path')
    parser.add_argument('--pretrained_model', default='', type=str, required=False,
                        help='path of the model to start training from')
    parser.add_argument('--segment', action='store_true', help='segment Chinese at the word level')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # set which GPUs the program may use

    model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = 999999
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the tokenized dataset from scratch
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    gradient_accumulation = args.gradient_accumulation
    fp16 = args.fp16  # do not enable this on GPUs that lack half-precision support
    fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    output_dir = args.output_dir

    if raw:
        print('building files')
        build_files(raw_data_path=raw_data_path, tokenized_data_path=tokenized_data_path,
                    full_tokenizer=full_tokenizer, num_pieces=num_pieces)
        print('files built')

    if not args.pretrained_model:
        model = transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.train()
    model.to(device)
    multi_gpu = False

    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
    scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=total_steps)
    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
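
    # Illustrative note on the sliding-window sampling below (the concrete numbers are
    # an assumption; n_ctx comes from model_config and is typically 1024 for the small
    # config): with stride=768, consecutive n_ctx-token samples overlap by
    # n_ctx - stride = 256 tokens, and a piece of N tokens yields roughly
    # (N - n_ctx) / stride + 1 samples. The same reasoning gives the total_steps
    # estimate above: about full_len / stride windows per epoch, divided by the
    # effective batch size (batch_size * gradient_accumulation).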
    print('starting training')
    running_loss = 0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        piece_num = 0
        for i in x:
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            start_point = 0
            samples = []
            while start_point < len(tokens) - n_ctx:
                samples.append(tokens[start_point: start_point + n_ctx])
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens) - n_ctx:])
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):
                # prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                batch_labels = []
                batch_inputs = []
                for ids in batch:
                    int_ids_for_labels = [int(x) for x in ids]
                    int_ids_for_inputs = [int(x) for x in ids]
                    batch_labels.append(int_ids_for_labels)
                    batch_inputs.append(int_ids_for_inputs)
                batch_labels = torch.tensor(batch_labels).long().to(device)
                batch_inputs = torch.tensor(batch_inputs).long().to(device)

                # forward pass
                outputs = model(input_ids=batch_inputs, labels=batch_labels)
                loss, logits = outputs[:2]

                # get loss
                if multi_gpu:
                    loss = loss.mean()
                if gradient_accumulation > 1:
                    loss = loss / gradient_accumulation

                # loss backward
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # optimizer step
                if (step + 1) % gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                if (step + 1) % log_step == 0:
                    print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        (step + 1) // gradient_accumulation,
                        piece_num,
                        epoch + 1,
                        running_loss / log_step))
                    running_loss = 0
            piece_num += 1

        print('saving model for epoch {}'.format(epoch + 1))
        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
        # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
        # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
        print('epoch {} finished'.format(epoch + 1))

        then = datetime.now()
        print('time: {}'.format(then))
        print('time for one epoch: {}'.format(then - now))

    print('training finished')
    if not os.path.exists(output_dir + 'final_model'):
        os.mkdir(output_dir + 'final_model')
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'final_model')
    # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
    # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')


if __name__ == '__main__':
    main()
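
# Example invocation (a sketch; the script name and paths are assumptions, adjust them
# to your own layout):
#
#   python train_single.py --raw \
#       --raw_data_path data/train.json --tokenized_data_path data/tokenized/ \
#       --model_config config/model_config_small.json --tokenizer_path cache/vocab_small.txt \
#       --epochs 5 --batch_size 8 --num_pieces 100 --output_dir model/ --device 0
#
# Pass --raw only on the first run (or whenever the corpus changes) so the tokenized
# pieces are rebuilt; later runs can reuse the files already in --tokenized_data_path.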