修复gradient accumulation

This commit is contained in:
Duzeyao
2019-11-20 10:52:50 +08:00
parent 2ed6a8d06e
commit 44d8bc66a2
3 changed files with 15 additions and 7 deletions

View File

@@ -17,7 +17,6 @@
## 项目状态
- 目前项目主要架构已经稳定。如发现任何bug或是有功能意见与改进欢迎提交IssuePR或是联系作者。
- 如使用梯度积累loss计算可能存在bug。
## 使用方法
@@ -61,7 +60,7 @@ python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save
## FP16与Gradient Accumulation支持
- 我在train.py文件中加入了fp16与gradient accumulation支持如果你安装了apex并且知道fp16是什么的话可以修改变量fp16=True来启用。但是目前fp16不收敛原因不明。
- 我在train.py文件中加入了fp16与gradient accumulation支持如果你安装了apex并且知道fp16是什么的话可以修改变量fp16=True来启用。但是目前fp16可能不收敛,原因不明。
## 联系作者

View File

@@ -0,0 +1,10 @@
{
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"n_ctx": 64,
"n_embd": 128,
"n_head": 2,
"n_layer": 1,
"n_positions": 64,
"vocab_size": 13317
}

View File

@@ -207,23 +207,22 @@ def main():
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
# optimizer step
if (step + 1) % gradient_accumulation == 0:
if (overall_step + 1) % gradient_accumulation == 0:
running_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
overall_step += 1
if (overall_step + 1) % log_step == 0:
tb_writer.add_scalar('loss', loss.item(), overall_step)
if (overall_step + 1) % log_step == 0:
tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step)
print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
datetime.now().hour,
datetime.now().minute,
step + 1,
piece_num,
epoch + 1,
running_loss * gradient_accumulation / log_step))
running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
running_loss = 0
overall_step += 1
piece_num += 1
print('saving model for epoch {}'.format(epoch + 1))