Remove data parallelism

This commit is contained in:
Prajwal 2020-09-15 19:58:44 +05:30 committed by GitHub
parent b3a417f48e
commit d7259e531b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -137,11 +137,9 @@ def cosine_loss(a, v, y):
return loss
def train(device, model_single, train_data_loader, test_data_loader, optimizer,
def train(device, model, train_data_loader, test_data_loader, optimizer,
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
model = nn.DataParallel(model_single)
global global_step, global_epoch
resumed_step = global_step
@ -170,7 +168,7 @@ def train(device, model_single, train_data_loader, test_data_loader, optimizer,
if global_step == 1 or global_step % checkpoint_interval == 0:
save_checkpoint(
model_single, optimizer, global_step, checkpoint_dir, global_epoch)
model, optimizer, global_step, checkpoint_dir, global_epoch)
if global_step % hparams.syncnet_eval_interval == 0:
with torch.no_grad():