from typing import Callable

import tensorflow as tf

from base.trainer import BaseTrain
from data_loader.data_loader import TFRecordDataLoader
from models.model import RawModel


class RawTrainer(BaseTrain):
    def __init__(
        self,
        config: dict,
        model: RawModel,
        train: TFRecordDataLoader,
        val: TFRecordDataLoader,
        pred: TFRecordDataLoader,
    ) -> None:
        """
        This function will generally remain unchanged; it is used to train and
        export the model. The only parts which may change are the run
        configuration and possibly which execution mode to use (training,
        eval, etc.).

        :param config: global configuration
        :param model: input function used to initialise the model
        :param train: the training dataset
        :param val: the evaluation dataset
        :param pred: the prediction dataset
        """
        super().__init__(config, model, train, val, pred)

    def run(self) -> None:
        # allow GPU memory usage to be scaled based on demand
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # get the number of steps required for one pass over the data
        steps_per_epoch = len(self.train) // self.config["train_batch_size"]

        # save_checkpoints_steps is the number of batches before eval
        run_config = tf.estimator.RunConfig(
            session_config=config,
            save_checkpoints_steps=steps_per_epoch * 10,  # number of batches before eval/checkpoint
            log_step_count_steps=steps_per_epoch,  # number of steps in an epoch
        )

        # set the output directory
        run_config = run_config.replace(model_dir=self.config["job_dir"])

        # initialise the estimator with your model
        estimator = tf.estimator.Estimator(model_fn=self.model.model, config=run_config)

        # create train and eval specs for the estimator; the tf.data.Dataset
        # returned by each input_fn is consumed automatically
        train_spec = tf.estimator.TrainSpec(
            lambda: self.train.input_fn(),
            max_steps=self.config["num_epochs"] * steps_per_epoch,
        )
        eval_spec = tf.estimator.EvalSpec(lambda: self.val.input_fn())

        # run training and evaluation; this also handles exporting
        # checkpoints and TensorBoard summaries
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

        # after training, export the final model for use in TensorFlow Serving
        self._export_model(estimator, self.config["export_path"])

        # get results after training and exporting the model
        self._predict(estimator, self.pred.input_fn)

    def _export_model(
        self, estimator: tf.estimator.Estimator, save_location: str
    ) -> None:
        """
        Used to export your model in a format that can be used with
        TensorFlow Serving.

        :param estimator: your estimator function
        :param save_location: directory to write the SavedModel to
        """
        # this should match the input shape of your model
        # TODO: update this to the input used in prediction/serving
        x1 = tf.feature_column.numeric_column(
            "input", shape=[self.config["batch_size"], 28, 28, 1]
        )
        # create a list in case you have more than one input
        feature_columns = [x1]
        feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
        export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            feature_spec
        )
        # export the saved model
        estimator.export_savedmodel(save_location, export_input_fn)

    def _predict(self, estimator: tf.estimator.Estimator, pred_fn: Callable) -> list:
        """
        Function to yield prediction results from the model

        :param estimator: your estimator function
        :param pred_fn: input_fn associated with prediction dataset
        :return: a list containing a prediction for each batch in the dataset
        """
        # minimal default: collect everything yielded by estimator.predict
        # into a list; override this if you need per-batch post-processing
        return list(estimator.predict(input_fn=pred_fn))
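

# ---------------------------------------------------------------------------
# Usage sketch (assumption): the config keys below are the ones referenced in
# this module, but the concrete values and the constructor signatures of
# RawModel / TFRecordDataLoader are hypothetical, since those classes are
# defined elsewhere in the repo. Adapt to the actual signatures before use.
#
# config = {
#     "job_dir": "./runs/raw_model",        # checkpoint / TensorBoard output
#     "export_path": "./exports/raw_model",  # SavedModel export directory
#     "train_batch_size": 128,
#     "batch_size": 128,
#     "num_epochs": 10,
# }
# model = RawModel(config)
# train_data = TFRecordDataLoader(config, mode="train")
# val_data = TFRecordDataLoader(config, mode="val")
# pred_data = TFRecordDataLoader(config, mode="pred")
#
# trainer = RawTrainer(config, model, train_data, val_data, pred_data)
# trainer.run()
# ---------------------------------------------------------------------------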