python bert code fix

This commit is contained in:
chongjiu.jin
2019-12-24 10:26:47 +08:00
parent 8b915566c0
commit 0f83d3f3cc
5 changed files with 54 additions and 30 deletions

View File

@@ -9,7 +9,23 @@ from utils import get_time_dif
from transformers.optimization import AdamW
# 权重初始化默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
for name, w in model.named_parameters():
if exclude not in name:
if len(w.size()) < 2:
continue
if 'weight' in name:
if method == 'xavier':
nn.init.xavier_normal_(w)
elif method == 'kaiming':
nn.init.kaiming_normal_(w)
else:
nn.init.normal_(w)
elif 'bias' in name:
nn.init.constant_(w, 0)
else:
pass
def train(config, model, train_iter, dev_iter, test_iter):