# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from transformers.optimization import AdamW
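

# Minimal sketch of the attributes that `config` is expected to expose in this
# module. Illustrative only: the real Config class lives elsewhere in the repo,
# and the values below are hypothetical placeholders.
class _ExampleConfig:
    learning_rate = 5e-5        # AdamW learning rate
    num_epochs = 3              # number of training epochs
    save_path = './model.ckpt'  # where the best checkpoint is saved (hypothetical path)
    require_improvement = 1000  # early-stopping patience, measured in batches
    class_list = ['label_0', 'label_1']  # class names for the test report (hypothetical)
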
def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    param_optimizer = list(model.named_parameters())
    # Bias and LayerNorm parameters are exempt from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
    total_batch = 0  # number of batches run so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last improvement in dev-set loss
    flag = False  # set True when training has stalled for too long
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Every 100 batches, report accuracy on the current training
                # batch and on the full dev set.
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                model.train()  # evaluate() switched the model to eval mode; switch back
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Stop early if the dev loss has not improved within
                # config.require_improvement batches.
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, model, test_iter)


def test(config, model, test_iter):
    # Evaluate the best saved checkpoint on the test set.
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()  # accumulate as a Python float, not a tensor
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
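

# Illustrative usage sketch (hypothetical wiring: the module path, Config class
# and iterator construction below are assumptions, not this repo's verified API):
#
#     from importlib import import_module
#     x = import_module('models.bert')
#     config = x.Config('THUCNews')
#     train_iter, dev_iter, test_iter = ...  # batch iterators over the dataset
#     model = x.Model(config).to(config.device)
#     train(config, model, train_iter, dev_iter, test_iter)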