Remove data parallelism

parent b3a417f48e
commit d7259e531b

1 changed file with 2 additions and 4 deletions
@@ -137,11 +137,9 @@ def cosine_loss(a, v, y):
     return loss
 
-def train(device, model_single, train_data_loader, test_data_loader, optimizer,
+def train(device, model, train_data_loader, test_data_loader, optimizer,
           checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
 
-    model = nn.DataParallel(model_single)
-
     global global_step, global_epoch
     resumed_step = global_step
 
 
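Note: the removed nn.DataParallel wrapper is what forced the old code to keep
two handles on the network. A minimal sketch of why, assuming PyTorch's stock
nn.DataParallel behavior (nn.Linear stands in for the project's actual model):

from torch import nn

model_single = nn.Linear(4, 2)           # stand-in for the real model
model = nn.DataParallel(model_single)    # replicates across visible GPUs

# The wrapper stores the network under .module, so its state_dict keys
# gain a "module." prefix that the bare model cannot load directly:
print(list(model_single.state_dict()))   # ['weight', 'bias']
print(list(model.state_dict()))          # ['module.weight', 'module.bias']

This is why the old train() had to pass model_single, not the wrapper, to
save_checkpoint.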
@@ -170,7 +168,7 @@ def train(device, model_single, train_data_loader, test_data_loader, optimizer,
 
             if global_step == 1 or global_step % checkpoint_interval == 0:
                 save_checkpoint(
-                    model_single, optimizer, global_step, checkpoint_dir, global_epoch)
+                    model, optimizer, global_step, checkpoint_dir, global_epoch)
 
             if global_step % hparams.syncnet_eval_interval == 0:
                 with torch.no_grad():
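With the wrapper gone, the object passed to train() is the same one that gets
checkpointed, so saved keys load back without stripping a prefix. A round-trip
sketch, under the assumption that save_checkpoint reduces to torch.save of the
model's state_dict (nn.Linear again stands in for the real model):

import torch
from torch import nn

model = nn.Linear(4, 2)
torch.save({"state_dict": model.state_dict()}, "checkpoint.pth")

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("checkpoint.pth")["state_dict"])
# Loads cleanly: the saved model was never wrapped, so there is no
# "module." prefix to strip on load.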