update to transformers

This commit is contained in:
Duzeyao
2019-10-25 23:37:34 +08:00
parent f4ab09186c
commit 9b630a53fe
8 changed files with 20 additions and 22 deletions

View File

@@ -1,10 +1,9 @@
import torch
import torch.nn.functional as F
import pytorch_transformers
import os
import argparse
from tqdm import trange
from pytorch_transformers import GPT2LMHeadModel
from transformers import GPT2LMHeadModel
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # 此处设置程序使用哪些显卡
@@ -139,7 +138,6 @@ def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
model_config = pytorch_transformers.GPT2Config.from_json_file(args.model_config)
model = GPT2LMHeadModel.from_pretrained(args.model_path)
model.to(device)
model.eval()