Bug fixes

This commit is contained in:
Prajwal 2020-08-28 21:06:39 +05:30 committed by GitHub
parent 151859f68c
commit d2dfac4597
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -284,7 +284,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
if global_step % hparams.eval_interval == 0:
with torch.no_grad():
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc, checkpoint_dir)
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
if average_sync_loss < .75:
hparams.set_hparam('syncnet_wt', 0.03)
@ -297,7 +297,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
global_epoch += 1
def eval_model(test_data_loader, global_step, writer, device, model, disc, checkpoint_dir):
def eval_model(test_data_loader, global_step, device, model, disc):
eval_steps = 300
print('Evaluating for {} steps'.format(eval_steps))
running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []