import json
import logging
import os
import sys
import time
from typing import Type

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import utils as keras_utils

from src import experiment_base


def configure_gpus(gpus):
    # set the visible GPU ids and TF session settings
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in gpus])
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=config))


def load_experiment_from_dir(from_dir,
                             exp_class: Type[experiment_base.Experiment],
                             load_n=None,
                             load_epoch=None,
                             log_to_dir=False,  # don't log if we are just loading this exp for evaluation
                             do_load_models=True):
    '''
    Loads a saved experiment using its saved parameters, and runs all
    initialization steps so that the models can be used right away.
    '''
    with open(os.path.join(from_dir, 'arch_params.json'), 'r') as f:
        fromdir_arch_params = json.load(f)
    fromdir_arch_params['exp_dir'] = from_dir

    with open(os.path.join(from_dir, 'data_params.json'), 'r') as f:
        fromdir_data_params = json.load(f)

    exp = exp_class(
        data_params=fromdir_data_params,
        arch_params=fromdir_arch_params,
        prompt_delete_existing=False,
        prompt_update_name=True,  # in case the experiment was renamed
        log_to_dir=log_to_dir)

    exp.load_data(load_n=load_n)
    exp.create_models()

    if do_load_models:
        loaded_epoch = exp.load_models(load_epoch)
    else:
        loaded_epoch = None

    return exp, loaded_epoch
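# Illustrative usage sketch (not called anywhere in this module). The
# experiment class and directory below are hypothetical stand-ins; any
# experiment_base.Experiment subclass whose directory contains the saved
# arch_params.json / data_params.json files should work the same way.
def _example_load_for_eval():
    from src.my_experiment import MyExperiment  # hypothetical Experiment subclass
    exp, loaded_epoch = load_experiment_from_dir(
        from_dir='experiments/my_exp',  # hypothetical saved experiment dir
        exp_class=MyExperiment,
        load_epoch=None,     # assumption: None lets exp.load_models pick a default checkpoint
        log_to_dir=False,    # evaluation only, so skip logging to the experiment dir
    )
    return exp, loaded_epoch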
def run_experiment(exp,
                   run_args,
                   end_epoch,
                   save_every_n_epochs,
                   test_every_n_epochs):
    if run_args.debug:
        # train for only a few epochs on a small dataset, and save/test often
        if run_args.epoch is not None:
            end_epoch = int(run_args.epoch) + 10
        else:
            end_epoch = 10

        if hasattr(run_args, 'loadn') and run_args.loadn is None:
            run_args.loadn = 1
        elif not hasattr(run_args, 'loadn'):
            run_args.loadn = None

        save_every_n_epochs = 2
        test_every_n_epochs = 2
        exp.set_debug_mode(True)

    if run_args.batch_size is None:
        run_args.batch_size = 8

    if not hasattr(run_args, 'ignore_missing'):
        run_args.ignore_missing = False

    exp_dir, figures_dir, logs_dir, models_dir = exp.get_dirs()

    # log to the newly created experiment dir
    formatter = logging.Formatter(
        '[%(asctime)s] %(message)s', '%Y-%m-%d %H:%M:%S')

    lfh = logging.FileHandler(
        filename=os.path.join(exp_dir, 'training.log'))
    lsh = logging.StreamHandler(sys.stdout)
    lfh.setFormatter(formatter)
    lsh.setFormatter(formatter)
    lfh.setLevel(logging.DEBUG)
    lsh.setLevel(logging.DEBUG)

    # one logger that writes to both the log file and stdout...
    file_stdout_logger = logging.getLogger('both')
    file_stdout_logger.setLevel(logging.DEBUG)
    file_stdout_logger.addHandler(lfh)
    file_stdout_logger.addHandler(lsh)

    # ...and one that writes to the log file only
    file_logger = logging.getLogger('file')
    file_logger.setLevel(logging.DEBUG)
    file_logger.addHandler(lfh)

    # load the dataset; load fewer examples if debugging
    exp.load_data(load_n=run_args.loadn)

    # create models and load existing weights if necessary
    exp.create_models()
    start_epoch = exp.load_models(
        run_args.epoch,
        stop_on_missing=not run_args.ignore_missing,
        init_layers=run_args.init_weights)

    # compile models for training
    exp.compile_models()

    if run_args.init_from:
        exp.init_model_weights(run_args.init_from)

    exp.create_generators(batch_size=run_args.batch_size)

    tbw = tf.summary.FileWriter(logs_dir)

    train_batch_by_batch(
        exp=exp,
        batch_size=run_args.batch_size,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        save_every_n_epochs=save_every_n_epochs,
        test_every_n_epochs=test_every_n_epochs,
        tbw=tbw,
        file_stdout_logger=file_stdout_logger,
        file_logger=file_logger,
        run_args=run_args,
    )
    return exp_dir
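# run_experiment only reads run_args attribute-by-attribute, so any
# namespace-like object works. A minimal argparse sketch covering every
# attribute this module touches; flag names and defaults are assumptions,
# not a documented CLI:
def _example_parse_run_args(argv=None):
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')        # shorten training, save/test often
    parser.add_argument('--epoch', type=int, default=None)     # epoch to resume from
    parser.add_argument('--loadn', type=int, default=None)     # limit on examples to load
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--ignore_missing', action='store_true')  # tolerate missing saved models
    parser.add_argument('--init_weights', default=None)        # layers to re-initialize on load
    parser.add_argument('--init_from', default=None)           # model file to copy weights from
    parser.add_argument('--print_every', type=int, default=60)  # seconds between result images
    return parser.parse_args(argv)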
def train_batch_by_batch(exp,
                         batch_size,
                         start_epoch,
                         end_epoch,
                         save_every_n_epochs,
                         test_every_n_epochs,
                         tbw,
                         file_stdout_logger,
                         file_logger,
                         run_args):
    # limit each epoch to batch_size * 1000 examples
    max_n_batch_per_epoch = 1000
    n_batch_per_epoch_train = min(
        max_n_batch_per_epoch,
        int(np.ceil(exp.get_n_train() / float(batch_size))))
    print('n_train examples: {}'.format(exp.get_n_train()))

    max_printed_examples = 8
    print_every = 100000  # set this really high at first; calibrated below
    print_atleast_every = 100
    print_atmost = max(1, max_printed_examples // batch_size)

    # aim for one new result image every run_args.print_every seconds
    print_every_n_seconds = run_args.print_every

    # save a new model every 20 minutes? seems reasonable
    auto_save_every_n_epochs = 100
    auto_test_every_n_epochs = 100
    min_save_every_n_epochs = 10
    save_every_n_seconds = 20 * 60

    start_time = time.time()

    # do this once here to flush any setup information to the file
    exp._reopen_log_file()

    for e in range(start_epoch, end_epoch + 1):
        file_stdout_logger.debug('{} training epoch {}/{}'.format(
            exp.model_name, e, end_epoch + 1))

        if e < end_epoch:
            exp.update_epoch_count(e)

        pb = keras_utils.Progbar(n_batch_per_epoch_train)
        printed_count = 0

        for bi in range(n_batch_per_epoch_train):
            joint_loss, joint_loss_names = exp.train_on_batch()
            batch_count = e * n_batch_per_epoch_train + bi

            # only log to file on the last batch of each epoch,
            # otherwise we'll have too many messages
            training_logger = None
            if bi == n_batch_per_epoch_train - 1:
                training_logger = file_logger

            log_losses(pb, tbw, training_logger,
                       joint_loss_names, joint_loss, batch_count)

            # time how long 5 batches take, then derive the printing and
            # saving intervals from the measured speed
            if batch_count - start_epoch * n_batch_per_epoch_train == 5:
                s_per_batch = (time.time() - start_time) / 5.

                # make this an odd integer in case our experiment is doing
                # different things on alternating batches, so that we can
                # visualize both
                print_every = int(np.ceil(
                    print_every_n_seconds / s_per_batch / 2.)) * 2 + 1

                auto_save_every_n_epochs = \
                    save_every_n_seconds / s_per_batch / n_batch_per_epoch_train
                if auto_save_every_n_epochs > 50:
                    # if the interval is big enough, round it down to a multiple of 50
                    auto_save_every_n_epochs = max(
                        1, int(np.floor(auto_save_every_n_epochs / 50))) * 50
                else:
                    auto_save_every_n_epochs = max(
                        1, int(np.floor(auto_save_every_n_epochs / min_save_every_n_epochs))) \
                        * min_save_every_n_epochs

            if (batch_count % print_every == 0
                    or batch_count % print_atleast_every == 0) \
                    and printed_count < print_atmost:
                results_im = exp.make_train_results_im()
                cv2.imwrite(
                    os.path.join(
                        exp.figures_dir,
                        'train_epoch{}_batch{}.jpg'.format(e, bi)),
                    results_im)
                printed_count += 1

                if batch_count >= 10:
                    file_stdout_logger.debug(
                        'Printing every {} batches, '
                        'saving every {} and {} epochs, '
                        'testing every {}'.format(
                            print_every,
                            auto_save_every_n_epochs,
                            save_every_n_epochs,
                            test_every_n_epochs))

        if (e > 0 and e % auto_save_every_n_epochs == 0 and e > start_epoch) \
                or e == end_epoch \
                or (e > 0 and e % save_every_n_epochs == 0 and e > start_epoch):
            exp.save_models(e, iter_count=e * n_batch_per_epoch_train)

            # TODO: figure out how to flush log file without closing
            file_stdout_logger.handlers[0].close()  # flush our .log file
            lfh = logging.FileHandler(
                filename=os.path.join(exp.exp_dir, 'training.log'))
            lfh.setFormatter(logging.Formatter(
                '[%(asctime)s] %(message)s', '%Y-%m-%d %H:%M:%S'))
            lfh.setLevel(logging.DEBUG)
            file_stdout_logger.handlers[0] = lfh
            if exp.logger is not None:
                exp.logger.handlers[0].close()
                exp._reopen_log_file()

            # save to disk and then open a new file so that we can read
            # into tensorboard more easily
            tbw.close()
            tbw.reopen()

        if e % auto_test_every_n_epochs == 0 or e % test_every_n_epochs == 0:
            file_stdout_logger.debug('{} testing'.format(exp.model_name))
            pbt = keras_utils.Progbar(1)
            test_loss, test_loss_names = exp.test_batches()
            log_losses(pbt, None, file_logger,
                       test_loss_names, test_loss,
                       e * n_batch_per_epoch_train + bi)

            results_im = exp.make_test_results_im()
            if results_im is not None:
                cv2.imwrite(
                    os.path.join(
                        exp.figures_dir,
                        'test_epoch{}_batch{}.jpg'.format(e, bi)),
                    results_im)

            log_losses(None, tbw, file_logger,
                       test_loss_names, test_loss,
                       e * n_batch_per_epoch_train + bi)
        print('\n\n')


def log_losses(progress_bar, tensorboard_writer, logger,
               loss_names, loss_vals, iter_count):
    if not isinstance(loss_vals, list):
        # occurs when a model only has one loss
        loss_vals = [loss_vals]

    # update the progress bar displayed in stdout
    if progress_bar is not None:
        progress_bar.add(1, values=[
            (loss_names[i], loss_vals[i]) for i in range(len(loss_vals))])

    # write to log using python logging
    if logger is not None:
        logger.debug(', '.join(['{}: {}'.format(loss_names[i], loss_vals[i])
                                for i in range(len(loss_vals))]))

    # write to tensorboard for pretty plots; iterate only over indices
    # that are valid for both loss_names and loss_vals
    if tensorboard_writer is not None:
        for i in range(min(len(loss_names), len(loss_vals))):
            tensorboard_writer.add_summary(
                tf.Summary(value=[tf.Summary.Value(
                    tag=loss_names[i], simple_value=loss_vals[i])]),
                iter_count)
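# End-to-end sketch of how the pieces above fit together (not called anywhere
# in this module). The experiment class, params, and GPU id are hypothetical
# stand-ins; the constructor kwargs mirror the ones used in
# load_experiment_from_dir, and run_args comes from the sketch above.
def _example_train():
    from src.my_experiment import MyExperiment  # hypothetical Experiment subclass
    configure_gpus([0])
    run_args = _example_parse_run_args(['--batch_size', '8', '--print_every', '60'])
    exp = MyExperiment(
        data_params={},  # hypothetical; normally read from data_params.json
        arch_params={},  # hypothetical; normally read from arch_params.json
        prompt_delete_existing=False,
        prompt_update_name=True,
        log_to_dir=True)
    exp_dir = run_experiment(
        exp, run_args,
        end_epoch=200,
        save_every_n_epochs=50,
        test_every_n_epochs=25)
    print('finished training; outputs in {}'.format(exp_dir))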