python bert code

chongjiu.jin committed 2019-12-13 14:52:07 +08:00
parent b68d3f1152
commit 6229486d35
8 changed files with 507 additions and 0 deletions

pytorch-bert-code/bert.py (new file, 53 lines)

@@ -0,0 +1,53 @@
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# from pytorch_pretrained_bert import BertModel, BertTokenizer
from transformers import BertModel, BertTokenizer


class Config(object):
    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'bert'
        self.train_path = dataset + '/data/train.txt'
        self.dev_path = dataset + '/data/dev.txt'
        self.test_path = dataset + '/data/test.txt'
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt', encoding='utf-8').readlines()]
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_classes = len(self.class_list)
        self.num_epochs = 3
        self.batch_size = 128
        self.pad_size = 32  # every sequence is padded/truncated to this length
        self.learning_rate = 5e-5
        self.bert_path = './bert'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True  # fine-tune all BERT weights
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, input_ids,  # token ids of the input sentence
                input_mask,       # attention mask, same size as the sentence;
                                  # padding positions are 0, e.g. [1, 1, 1, 1, 0, 0]
                segments_ids      # token type (segment) ids
                ):
        # transformers >= 3.1 returns a ModelOutput by default; there, use
        # outputs.pooler_output or pass return_dict=False instead.
        _, pooled = self.bert(input_ids,
                              attention_mask=input_mask,
                              token_type_ids=segments_ids)  # pooled: [batch_size, hidden_size]
        out = self.fc(pooled)
        return out

    def loss(self, outputs, labels):
        criterion = F.cross_entropy
        loss = criterion(outputs, labels)
        return loss
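
For reference, a minimal smoke test of the model above might look like the sketch below. It is not part of the commit: the 'THUCNews' dataset directory is a hypothetical example, and the tokenizer call uses the transformers 2.x-era encode_plus API (newer versions spell it padding='max_length').

# Hypothetical usage sketch; assumes ./bert holds a BERT checkpoint and
# THUCNews/data/class.txt lists one class label per line.
config = Config('THUCNews')
model = Model(config).to(config.device)

enc = config.tokenizer.encode_plus(
    'replace me with a real sentence',
    max_length=config.pad_size,
    pad_to_max_length=True)  # truncate/pad to pad_size (transformers 2.x API)

input_ids = torch.tensor([enc['input_ids']]).to(config.device)
input_mask = torch.tensor([enc['attention_mask']]).to(config.device)
segments_ids = torch.tensor([enc['token_type_ids']]).to(config.device)

logits = model(input_ids, input_mask, segments_ids)  # [1, num_classes]
labels = torch.tensor([0]).to(config.device)  # dummy label for the demo
print(model.loss(logits, labels).item())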