bugfix
This commit is contained in:
2
train.py
2
train.py
@@ -151,7 +151,7 @@ def main():
|
|||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||||
model = DataParallel(model)
|
model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')])
|
||||||
multi_gpu = True
|
multi_gpu = True
|
||||||
print('starting training')
|
print('starting training')
|
||||||
overall_step = 0
|
overall_step = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user