This commit is contained in:
Duzeyao
2019-11-07 15:05:57 +08:00
parent 787da7a601
commit 00bed53423

View File

@@ -151,7 +151,7 @@ def main():
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = DataParallel(model)
model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')])
multi_gpu = True
print('starting training')
overall_step = 0