From 00bed53423b04c9bd5edaae11da8c66548b37911 Mon Sep 17 00:00:00 2001 From: Duzeyao <330501241@qq.com> Date: Thu, 7 Nov 2019 15:05:57 +0800 Subject: [PATCH] bugfix --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 7cebdd0..ae001cb 100644 --- a/train.py +++ b/train.py @@ -151,7 +151,7 @@ def main(): if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") - model = DataParallel(model) + model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')]) multi_gpu = True print('starting training') overall_step = 0