#!/usr/bin/env python
# coding: utf8

"""
    Entrypoint provider for performing model training.

        USAGE: python -m spleeter train -p /path/to/params
"""

from functools import partial

# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error

from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

# Fallback values used by _create_estimator when the configuration does not
# provide an override. They equal the values this module used to hard-code,
# so existing configuration files keep their previous behavior.
_DEFAULT_GPU_MEMORY_FRACTION = 0.45
_DEFAULT_LOG_STEP_COUNT_STEPS = 10
_DEFAULT_KEEP_CHECKPOINT_MAX = 2


def _create_estimator(params):
    """ Creates estimator.

    :param params: TF params to build estimator from. Required keys:
        ``model_dir``, ``save_checkpoints_steps``, ``random_seed`` and
        ``save_summary_steps``. Optional keys (default in parentheses):
        ``gpu_memory_fraction`` (0.45), ``log_step_count_steps`` (10) and
        ``keep_checkpoint_max`` (2).
    :returns: Built estimator.
    """
    session_config = tf.compat.v1.ConfigProto()
    # Limit the fraction of GPU memory this process is allowed to allocate.
    session_config.gpu_options.per_process_gpu_memory_fraction = params.get(
        'gpu_memory_fraction', _DEFAULT_GPU_MEMORY_FRACTION)
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=params.get(
            'log_step_count_steps', _DEFAULT_LOG_STEP_COUNT_STEPS),
        keep_checkpoint_max=params.get(
            'keep_checkpoint_max', _DEFAULT_KEEP_CHECKPOINT_MAX))
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)
    return estimator


def _create_train_spec(params, audio_adapter, audio_path):
    """ Creates train spec.

    :param params: TF params to build spec from (reads ``train_max_steps``
        and is forwarded to ``get_training_dataset``).
    :param audio_adapter: Audio adapter forwarded to ``get_training_dataset``.
    :param audio_path: Audio path forwarded to ``get_training_dataset``
        (presumably the root directory of the training audio — confirm
        against the dataset module).
    :returns: Built train spec.
    """
    # The estimator expects a zero-argument input_fn, so bind the arguments
    # now and let TF call it lazily.
    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn,
        max_steps=params['train_max_steps'])
    return train_spec


def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Creates evaluation spec, with evaluations throttled to at most one
    every ``params['throttle_secs']`` seconds.

    :param params: TF params to build spec from (reads ``throttle_secs``
        and is forwarded to ``get_validation_dataset``).
    :param audio_adapter: Audio adapter forwarded to
        ``get_validation_dataset``.
    :param audio_path: Audio path forwarded to ``get_validation_dataset``.
    :returns: Built evaluation spec.
    """
    input_fn = partial(
        get_validation_dataset,
        params,
        audio_adapter,
        audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        # None: evaluate until the validation input is exhausted rather than
        # for a fixed number of steps.
        steps=None,
        throttle_secs=params['throttle_secs'])
    return evaluation_spec


def entrypoint(arguments, params):
    """ Command entrypoint.

    Builds the estimator and train/eval specs, runs
    ``tf.estimator.train_and_evaluate``, then writes a probe into the model
    directory through ``ModelProvider.writeProbe``.

    :param arguments: Command line parsed argument as argparse.Namespace
        (reads ``audio_adapter`` and ``audio_path``).
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    audio_adapter = get_audio_adapter(arguments.audio_adapter)
    audio_path = arguments.audio_path
    estimator = _create_estimator(params)
    train_spec = _create_train_spec(params, audio_adapter, audio_path)
    evaluation_spec = _create_evaluation_spec(
        params,
        audio_adapter,
        audio_path)
    get_logger().info('Start model training')
    tf.estimator.train_and_evaluate(
        estimator,
        train_spec,
        evaluation_spec)
    # Only reached when training finished without raising.
    ModelProvider.writeProbe(params['model_dir'])
    get_logger().info('Model training done')