Bug fixes
This commit is contained in:
parent
151859f68c
commit
d2dfac4597
1 changed file with 2 additions and 2 deletions
|
|
@@ -284,7 +284,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
|
|||
|
||||
if global_step % hparams.eval_interval == 0:
|
||||
with torch.no_grad():
|
||||
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc, checkpoint_dir)
|
||||
average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
|
||||
|
||||
if average_sync_loss < .75:
|
||||
hparams.set_hparam('syncnet_wt', 0.03)
|
||||
|
|
@@ -297,7 +297,7 @@ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, d
|
|||
|
||||
global_epoch += 1
|
||||
|
||||
def eval_model(test_data_loader, global_step, writer, device, model, disc, checkpoint_dir):
|
||||
def eval_model(test_data_loader, global_step, device, model, disc):
|
||||
eval_steps = 300
|
||||
print('Evaluating for {} steps'.format(eval_steps))
|
||||
running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
|
||||
|
|
|
|||
Loading…
Reference in a new issue