This commit is contained in:
Duzeyao
2019-11-07 15:05:57 +08:00
parent 787da7a601
commit 00bed53423

View File

@@ -151,7 +151,7 @@ def main():
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!") print("Let's use", torch.cuda.device_count(), "GPUs!")
model = DataParallel(model) model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')])
multi_gpu = True multi_gpu = True
print('starting training') print('starting training')
overall_step = 0 overall_step = 0