diff --git a/train.py b/train.py index 7cebdd0..ae001cb 100644 --- a/train.py +++ b/train.py @@ -151,7 +151,7 @@ def main(): if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") - model = DataParallel(model) + model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')]) multi_gpu = True print('starting training') overall_step = 0