update to transformers

This commit is contained in:
Duzeyao
2019-10-25 23:37:34 +08:00
parent f4ab09186c
commit 9b630a53fe
8 changed files with 20 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
import pytorch_transformers
import transformers
import torch
import os
import json
@@ -65,7 +65,7 @@ def main():
os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
model_config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
print('config:\n' + model_config.to_json_string())
n_ctx = model_config.n_ctx
@@ -97,7 +97,7 @@ def main():
print('you need to specify a trained model.')
exit(1)
else:
model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
model.eval()
model.to(device)