Fix gradient accumulation

README.md
@@ -17,7 +17,6 @@
 ## Project Status
 
 - The core architecture of the project is now stable. If you find any bugs, or have feature suggestions or improvements, feel free to open an Issue or a PR, or to contact the author.
-- When using gradient accumulation, the loss calculation may be buggy.
 
 ## Usage
 
@@ -61,7 +60,7 @@ python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save
 
 ## FP16 and Gradient Accumulation Support
 
-- I added fp16 and gradient accumulation support to train.py. If you have apex installed and know what fp16 is, you can set the variable fp16=True to enable it. However, fp16 currently does not converge, for reasons unknown.
+- I added fp16 and gradient accumulation support to train.py. If you have apex installed and know what fp16 is, you can set the variable fp16=True to enable it. However, fp16 currently may not converge, for reasons unknown.
 
 ## Contact the Author
 
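
For readers unfamiliar with the feature that README section describes, below is a minimal sketch of how apex fp16 training is typically combined with gradient accumulation. It uses apex's amp API; the function and variable names are illustrative assumptions, not the repo's actual train.py code.

```python
# Sketch: fp16 + gradient accumulation with NVIDIA apex.
# Names (train_fp16, dataloader, gradient_accumulation) are illustrative
# assumptions, not the repo's actual code.
from apex import amp

def train_fp16(model, optimizer, dataloader, gradient_accumulation=4):
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = model(inputs, labels=labels)[0]
        loss = loss / gradient_accumulation  # average the update over micro-batches
        # amp scales the loss so fp16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if (step + 1) % gradient_accumulation == 0:
            optimizer.step()       # one parameter update per accumulated batch
            optimizer.zero_grad()
```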

config/model_config_test.json (new file)
@@ -0,0 +1,10 @@
+{
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "n_ctx": 64,
+  "n_embd": 128,
+  "n_head": 2,
+  "n_layer": 1,
+  "n_positions": 64,
+  "vocab_size": 13317
+}
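
The new config defines a deliberately tiny model (1 layer, 2 heads, 128-dim embeddings, 64-token context), which suggests it exists for quick smoke tests of the training pipeline. A minimal sketch of building a model from it, assuming the Hugging Face pytorch_transformers classes; whether train.py consumes the file exactly this way is an assumption:

```python
# Sketch: instantiate a tiny GPT-2 from the new test config.
# Assumes pytorch_transformers; train.py's actual loading code may differ.
from pytorch_transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config.from_json_file('config/model_config_test.json')
model = GPT2LMHeadModel(config=config)  # 1 layer, 2 heads: builds in seconds
print(config.n_ctx)  # 64
```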

train.py
@@ -207,23 +207,22 @@ def main():
                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
 
                 #  optimizer step
-                if (step + 1) % gradient_accumulation == 0:
+                if (overall_step + 1) % gradient_accumulation == 0:
                     running_loss += loss.item()
                     optimizer.step()
                     optimizer.zero_grad()
                     scheduler.step()
-                    overall_step += 1
-                    if (overall_step + 1) % log_step == 0:
-                        tb_writer.add_scalar('loss', loss.item(), overall_step)
                 if (overall_step + 1) % log_step == 0:
+                    tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step)
                     print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                         datetime.now().hour,
                         datetime.now().minute,
                         step + 1,
                         piece_num,
                         epoch + 1,
-                        running_loss * gradient_accumulation / log_step))
+                        running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
                     running_loss = 0
+                overall_step += 1
             piece_num += 1
 
         print('saving model for epoch {}'.format(epoch + 1))
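
The arithmetic behind the logging change, assuming train.py divides each micro-batch loss by gradient_accumulation before backward() (standard accumulation practice, and what the tb_writer multiplication implies): running_loss receives one scaled loss.item() per optimizer step, so a window of log_step micro-steps contributes only log_step / gradient_accumulation samples, and recovering the mean unscaled loss needs both correction factors. A standalone sketch with toy numbers:

```python
# Sketch of the loss bookkeeping the fix implements (toy numbers,
# not the repo's training loop).
gradient_accumulation = 4
log_step = 8
raw_losses = [2.0] * log_step  # pretend every micro-batch loss is 2.0

running_loss = 0.0
for step, raw in enumerate(raw_losses):
    scaled = raw / gradient_accumulation         # loss as scaled before backward()
    if (step + 1) % gradient_accumulation == 0:  # once per optimizer step
        running_loss += scaled
mean_loss = running_loss * gradient_accumulation / (log_step / gradient_accumulation)
print(mean_loss)  # 2.0 -- recovers the true per-micro-batch loss
```

The other fix visible in the hunk, moving overall_step += 1 out of the accumulation branch, makes the counter advance on every micro-step instead of only on update steps, and switching the accumulation check from step to overall_step keeps update boundaries consistent across piece boundaries, since step restarts for each data piece.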