bugfix
This commit is contained in:
2
train.py
2
train.py
@@ -151,7 +151,7 @@ def main():
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = DataParallel(model)
|
||||
model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')])
|
||||
multi_gpu = True
|
||||
print('starting training')
|
||||
overall_step = 0
|
||||
|
||||
Reference in New Issue
Block a user