Update train.py

This commit is contained in:
Duzeyao
2019-11-13 15:59:20 +08:00
parent 3f1b188e24
commit 2ed6a8d06e

View File

@@ -107,6 +107,7 @@ def main():
min_length = args.min_length min_length = args.min_length
output_dir = args.output_dir output_dir = args.output_dir
tb_writer = SummaryWriter(log_dir=args.writer_dir) tb_writer = SummaryWriter(log_dir=args.writer_dir)
assert log_step % gradient_accumulation == 0
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.mkdir(output_dir) os.mkdir(output_dir)