Update train.py
This commit is contained in:
1
train.py
1
train.py
@@ -107,6 +107,7 @@ def main():
|
|||||||
min_length = args.min_length
|
min_length = args.min_length
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
tb_writer = SummaryWriter(log_dir=args.writer_dir)
|
tb_writer = SummaryWriter(log_dir=args.writer_dir)
|
||||||
|
assert log_step % gradient_accumulation == 0
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.mkdir(output_dir)
|
os.mkdir(output_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user