cs224n_2019/pytorch-bert-code/bert.py

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# from pytorch_pretrained_bert import BertModel, BertTokenizer
from transformers import BertModel, BertTokenizer
class Config(object):
    """Configuration parameters."""

    def __init__(self, dataset):
        self.model_name = 'bert'
        self.train_path = dataset + '/data/train.txt'
        self.dev_path = dataset + '/data/dev.txt'
        self.test_path = dataset + '/data/test.txt'
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt', encoding='utf-8').readlines()]
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_classes = len(self.class_list)
        self.num_epochs = 3
        self.batch_size = 128
        self.pad_size = 32        # every sentence is padded/truncated to this length
        self.learning_rate = 5e-5
        self.bert_path = './bert'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True    # fine-tune all BERT parameters
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self,
                input_ids,     # input sentence as token ids
                input_mask,    # mask for the padded positions, same size as the sentence; padding is 0, e.g. [1, 1, 1, 1, 0, 0]
                segments_ids): # token type (segment) ids
        # pooled: [batch_size, hidden_size]. Tuple unpacking assumes transformers 2.x-style
        # outputs; on transformers >= 4.x pass return_dict=False or use outputs.pooler_output.
        _, pooled = self.bert(input_ids, attention_mask=input_mask, token_type_ids=segments_ids)
        out = self.fc(pooled)
        return out

    def loss(self, outputs, labels):
        criterion = F.cross_entropy
        loss = criterion(outputs, labels)
        return loss
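

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how the
# tokenizer in Config can produce the input_ids / input_mask / segments_ids
# that Model.forward expects. It assumes './bert' holds a pretrained BERT
# checkpoint and that './dataset' (a placeholder path) contains
# data/class.txt; adjust both to your local layout.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    config = Config('./dataset')              # hypothetical dataset directory
    model = Model(config).to(config.device)

    text = '这是一条测试新闻'                  # a sample sentence to classify
    tokens = ['[CLS]'] + config.tokenizer.tokenize(text)
    tokens = tokens[:config.pad_size]          # truncate to pad_size
    input_ids = config.tokenizer.convert_tokens_to_ids(tokens)
    mask = [1] * len(input_ids)                # 1 for real tokens, 0 for padding
    padding = [0] * (config.pad_size - len(input_ids))
    input_ids += padding
    mask += padding
    segments = [0] * config.pad_size           # single-sentence input: all zeros

    input_ids = torch.tensor([input_ids], device=config.device)
    mask = torch.tensor([mask], device=config.device)
    segments = torch.tensor([segments], device=config.device)

    with torch.no_grad():
        logits = model(input_ids, mask, segments)   # [1, num_classes]
    print(logits.argmax(dim=1))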