From 6777c982ab09b4501637f35e52bf7b37a65ec4b5 Mon Sep 17 00:00:00 2001 From: mshahbazi72 Date: Thu, 22 Oct 2020 12:26:46 +0200 Subject: [PATCH] fix saving best test metrics --- train_fns.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train_fns.py b/train_fns.py index 5a4d2a9b..4d92aaec 100644 --- a/train_fns.py +++ b/train_fns.py @@ -172,13 +172,16 @@ def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics, # If improved over previous best metric, save approrpiate copy if ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS']) or (config['which_best'] == 'FID' and FID < state_dict['best_FID'])): + + state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean) + state_dict['best_FID'] = min(state_dict['best_FID'], FID) + print('%s improved over previous best, saving checkpoint...' % config['which_best']) utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None) state_dict['save_best_num'] = (state_dict['save_best_num'] + 1 ) % config['num_best_copies'] - state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean) - state_dict['best_FID'] = min(state_dict['best_FID'], FID) + # Log results to file test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean), - IS_std=float(IS_std), FID=float(FID)) \ No newline at end of file + IS_std=float(IS_std), FID=float(FID))