This commit is contained in:
Duzeyao
2019-11-13 15:58:05 +08:00
parent 5b5dcc15e5
commit 3f1b188e24

View File

@@ -52,7 +52,7 @@ def main():
parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size') parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size')
parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率') parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率')
parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数') parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数')
parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss设置为gradient accumulation的整数倍')
parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长') parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长')
parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累')
parser.add_argument('--fp16', action='store_true', help='混合精度') parser.add_argument('--fp16', action='store_true', help='混合精度')
@@ -221,7 +221,7 @@ def main():
step + 1, step + 1,
piece_num, piece_num,
epoch + 1, epoch + 1,
running_loss / log_step)) running_loss * gradient_accumulation / log_step))
running_loss = 0 running_loss = 0
piece_num += 1 piece_num += 1