From d7259e531bf984c6af1eef6ea4511923694c1ac6 Mon Sep 17 00:00:00 2001 From: Prajwal Date: Tue, 15 Sep 2020 19:58:44 +0530 Subject: [PATCH] Remove data parallelism --- color_syncnet_train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/color_syncnet_train.py b/color_syncnet_train.py index 8ae099e..afa0054 100644 --- a/color_syncnet_train.py +++ b/color_syncnet_train.py @@ -137,11 +137,9 @@ def cosine_loss(a, v, y): return loss -def train(device, model_single, train_data_loader, test_data_loader, optimizer, +def train(device, model, train_data_loader, test_data_loader, optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None): - model = nn.DataParallel(model_single) - global global_step, global_epoch resumed_step = global_step @@ -170,7 +168,7 @@ def train(device, model_single, train_data_loader, test_data_loader, optimizer, if global_step == 1 or global_step % checkpoint_interval == 0: save_checkpoint( - model_single, optimizer, global_step, checkpoint_dir, global_epoch) + model, optimizer, global_step, checkpoint_dir, global_epoch) if global_step % hparams.syncnet_eval_interval == 0: with torch.no_grad():