Remove data parallelism
This commit is contained in:
parent
b3a417f48e
commit
d7259e531b
1 changed files with 2 additions and 4 deletions
|
|
@ -137,11 +137,9 @@ def cosine_loss(a, v, y):
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def train(device, model_single, train_data_loader, test_data_loader, optimizer,
|
def train(device, model, train_data_loader, test_data_loader, optimizer,
|
||||||
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
|
||||||
|
|
||||||
model = nn.DataParallel(model_single)
|
|
||||||
|
|
||||||
global global_step, global_epoch
|
global global_step, global_epoch
|
||||||
resumed_step = global_step
|
resumed_step = global_step
|
||||||
|
|
||||||
|
|
@ -170,7 +168,7 @@ def train(device, model_single, train_data_loader, test_data_loader, optimizer,
|
||||||
|
|
||||||
if global_step == 1 or global_step % checkpoint_interval == 0:
|
if global_step == 1 or global_step % checkpoint_interval == 0:
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
model_single, optimizer, global_step, checkpoint_dir, global_epoch)
|
model, optimizer, global_step, checkpoint_dir, global_epoch)
|
||||||
|
|
||||||
if global_step % hparams.syncnet_eval_interval == 0:
|
if global_step % hparams.syncnet_eval_interval == 0:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue