Fix gradient accumulation
@@ -17,7 +17,6 @@
 
 ## Project Status
 
 - The main architecture of the project is now stable. If you find any bugs or have feature suggestions or improvements, feel free to open an Issue or PR, or to contact the author.
-- If gradient accumulation is used, there may be a bug in the loss calculation.
 
 ## Usage
@@ -61,7 +60,7 @@ python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save
 
 ## FP16 and Gradient Accumulation Support
 
-- I have added fp16 and gradient accumulation support in train.py. If you have apex installed and know what fp16 is, you can set the variable fp16=True to enable it. However, fp16 currently does not converge, for unknown reasons.
+- I have added fp16 and gradient accumulation support in train.py. If you have apex installed and know what fp16 is, you can set the variable fp16=True to enable it. However, fp16 currently may not converge, for unknown reasons.
 
 ## Contact the Author
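For readers unfamiliar with the pattern the bullet above describes, here is a minimal sketch of apex fp16 training combined with gradient accumulation. The `model`, `optimizer`, `scheduler`, and `dataloader` objects are assumed to already exist; this illustrates the general technique, not the exact code in train.py.

```python
# Sketch of fp16 (apex) + gradient accumulation; not the literal train.py code.
from apex import amp  # requires NVIDIA apex to be installed

gradient_accumulation = 4
fp16 = True

if fp16:
    # Wrap model and optimizer for mixed-precision training
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    loss = model(batch, labels=batch)[0]
    loss = loss / gradient_accumulation  # average over the accumulation window
    if fp16:
        # Scale the loss so fp16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    if (step + 1) % gradient_accumulation == 0:
        optimizer.step()       # apply the accumulated gradients
        scheduler.step()
        optimizer.zero_grad()
```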
config/model_config_test.json (new file, +10 lines)
@@ -0,0 +1,10 @@
+{
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "n_ctx": 64,
+  "n_embd": 128,
+  "n_head": 2,
+  "n_layer": 1,
+  "n_positions": 64,
+  "vocab_size": 13317
+}
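The new config describes a deliberately tiny model (1 layer, 2 heads, 128-dim embeddings, 64-token context), which looks intended for quick smoke tests rather than real training. As an illustration, a config file in this format can be loaded with the GPT2Config class from the transformers family of libraries; the snippet below is a hypothetical usage sketch, not code from this repository.

```python
# Hypothetical sketch: build a tiny smoke-test model from the new config file.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config.from_json_file('config/model_config_test.json')
model = GPT2LMHeadModel(config=config)   # 1 layer, 2 heads: trains in seconds
print(config.n_ctx, config.vocab_size)  # 64 13317
```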
9
train.py
9
train.py
@@ -207,23 +207,22 @@ def main():
 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
 
 # optimizer step
-if (step + 1) % gradient_accumulation == 0:
+if (overall_step + 1) % gradient_accumulation == 0:
     running_loss += loss.item()
     optimizer.step()
     optimizer.zero_grad()
     scheduler.step()
-if (overall_step + 1) % log_step == 0:
-    tb_writer.add_scalar('loss', loss.item(), overall_step)
+overall_step += 1
+if (overall_step + 1) % log_step == 0:
+    tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step)
     print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
         datetime.now().hour,
         datetime.now().minute,
         step + 1,
         piece_num,
         epoch + 1,
-        running_loss * gradient_accumulation / log_step))
+        running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
     running_loss = 0
-overall_step += 1
 piece_num += 1
 
 print('saving model for epoch {}'.format(epoch + 1))
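Why the logging denominator changes: running_loss is only incremented inside the optimizer-step branch, i.e. once every gradient_accumulation batches, so a window of log_step batches contributes only log_step / gradient_accumulation samples. The quick check below uses made-up numbers and assumes, as the `* gradient_accumulation` factor in both formulas suggests, that loss has already been divided by gradient_accumulation before backward().

```python
# Hypothetical arithmetic check of the corrected per-batch loss average.
gradient_accumulation = 4
log_step = 100    # logging window, measured in batches
true_loss = 2.0   # pretend every batch has this raw loss

# running_loss is incremented once per gradient_accumulation batches,
# each time with a loss already divided by gradient_accumulation:
updates_per_window = log_step / gradient_accumulation                     # 25.0
running_loss = updates_per_window * (true_loss / gradient_accumulation)   # 12.5

old_avg = running_loss * gradient_accumulation / log_step                 # 0.5 (understated)
new_avg = running_loss * gradient_accumulation / (log_step / gradient_accumulation)  # 2.0
print(old_avg, new_avg)  # 0.5 2.0 -> the new formula recovers the true loss
```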