From d2dfac459767d183fb47c568050134c12af659f3 Mon Sep 17 00:00:00 2001
From: Prajwal
Date: Fri, 28 Aug 2020 21:06:39 +0530
Subject: [PATCH] Bug fixes

---
 hq_wav2lip_train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/hq_wav2lip_train.py b/hq_wav2lip_train.py
index 9761b9f..86db55c 100644
--- a/hq_wav2lip_train.py
+++ b/hq_wav2lip_train.py
@@ -284,7 +284,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
 
             if global_step % hparams.eval_interval == 0:
                 with torch.no_grad():
-                    average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc, checkpoint_dir)
+                    average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
 
                     if average_sync_loss < .75:
                         hparams.set_hparam('syncnet_wt', 0.03)
@@ -297,7 +297,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
 
         global_epoch += 1
 
-def eval_model(test_data_loader, global_step, writer, device, model, disc, checkpoint_dir):
+def eval_model(test_data_loader, global_step, device, model, disc):
     eval_steps = 300
     print('Evaluating for {} steps'.format(eval_steps))
     running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []