python bert code fix
This commit is contained in:
@@ -9,7 +9,23 @@ from utils import get_time_dif
|
||||
from transformers.optimization import AdamW
|
||||
|
||||
|
||||
|
||||
# 权重初始化,默认xavier
|
||||
def init_network(model, method='xavier', exclude='embedding', seed=123):
|
||||
for name, w in model.named_parameters():
|
||||
if exclude not in name:
|
||||
if len(w.size()) < 2:
|
||||
continue
|
||||
if 'weight' in name:
|
||||
if method == 'xavier':
|
||||
nn.init.xavier_normal_(w)
|
||||
elif method == 'kaiming':
|
||||
nn.init.kaiming_normal_(w)
|
||||
else:
|
||||
nn.init.normal_(w)
|
||||
elif 'bias' in name:
|
||||
nn.init.constant_(w, 0)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def train(config, model, train_iter, dev_iter, test_iter):
|
||||
|
||||
Reference in New Issue
Block a user