update to transformers
This commit is contained in:
6
eval.py
6
eval.py
@@ -1,4 +1,4 @@
|
||||
import pytorch_transformers
|
||||
import transformers
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
|
||||
|
||||
model_config = pytorch_transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
|
||||
model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
|
||||
print('config:\n' + model_config.to_json_string())
|
||||
|
||||
n_ctx = model_config.n_ctx
|
||||
@@ -97,7 +97,7 @@ def main():
|
||||
print('you need to specify a trained model.')
|
||||
exit(1)
|
||||
else:
|
||||
model = pytorch_transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
|
||||
model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user