# Copyright 2017 Natural Language Processing Group, Nanjing University, zhaocq.nlp@gmail.com. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Abstract base class for objects that are configurable using a parameters dictionary. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import copy import os import six import yaml import tensorflow as tf from tensorflow import gfile from njunmt.utils.constants import Constants, ModeKeys from njunmt.utils.misc import open_file class abstractstaticmethod(staticmethod): # pylint: disable=C0111,C0103 # """Decorates a method as abstract and static""" __slots__ = () def __init__(self, function): super(abstractstaticmethod, self).__init__(function) function.__isabstractmethod__ = True __isabstractmethod__ = True def _toggle_dropout(params, mode): """ Disable dropout probability during EVAL/INFER mode. Args: params: A dictionary of parameters. mode: A mode. Returns: A result dictionary. """ params = copy.deepcopy(params) if mode != ModeKeys.TRAIN: for key, val in params.items(): if type(val) is dict: params[key] = _toggle_dropout(params[key], mode) elif "dropout" in key: params[key] = 1.0 if "keep" in key else 0. return params def _params_to_stringlist(params, prefix=" "): """ Convert a dictionary/list of parameters to a format string. Args: params: A dictionary/list of parameters. prefix: A string. Returns: A format string. Raises: ValueError: if unknown type of `params`. """ param_list = [] if isinstance(params, dict): for key, val in params.items(): param_list.append(prefix + key + ": ") if isinstance(val, dict): param_list.extend(_params_to_stringlist(val, prefix + " ")) else: param_list[-1] += str(val) elif isinstance(params, list): prefix += " " for item in params: for idx, (key, val) in enumerate(item.items()): if idx == 0: newprefix = copy.deepcopy(prefix[:-2]) newprefix += "- " param_list.append(newprefix + key + ": ") else: param_list.append(prefix + key + ": ") if isinstance(val, dict): param_list.extend(_params_to_stringlist(val, prefix + " ")) else: param_list[-1] += str(val) else: raise ValueError("Unrecognized type of params: {}".format(str(params))) return param_list def define_tf_flags(args): """ Defines tf FLAGS. Args: args: A dict, with format: {arg_name: [type, default_val, helper]} Returns: tf FLAGS. """ for key, val in args.items(): eval("tf.flags.DEFINE_{}".format(val[0]))(key, val[1], val[2]) return tf.flags.FLAGS def update_configs_from_flags(model_configs, tf_flags, flag_keys): """ Replaces `model_configs` with options defined in `tf_flags`. Args: model_configs: A dict. tf_flags: tf FLAGS. flag_keys: A set of keys. Returns: The updated dict. """ def _update(mc, param_name): param_str = getattr(tf_flags, param_name) if param_str is None: return mc params = yaml.load(param_str) if params is None: return mc return deep_merge_dict(model_configs, {param_name: params}) for key in flag_keys: model_configs = _update(model_configs, key) return model_configs def load_from_config_path(config_paths): """ Loads configurations from files of yaml format. Args: config_paths: A string (each file name is seperated by ",") or a list of strings (file names). Returns: A dictionary of model configurations, parsed from config files. """ if isinstance(config_paths, six.string_types): config_paths = config_paths.strip().split(",") assert isinstance(config_paths, list) or isinstance(config_paths, tuple) model_configs = dict() for config_path in config_paths: config_path = config_path.strip() if not config_path: continue if not gfile.Exists(config_path): raise OSError("config file does not exist: {}".format(config_path)) config_path = os.path.abspath(config_path) tf.logging.info("loading configurations from {}".format(config_path)) with open_file(config_path, mode="r") as config_file: config_flags = yaml.load(config_file) model_configs = deep_merge_dict(model_configs, config_flags) return model_configs def maybe_load_yaml(item): """Parses `item` only if it is a string. If `item` is a dictionary it is returned as-is. Args: item: Returns: A dictionary. Raises: ValueError: if unknown type of `item`. """ if isinstance(item, six.string_types): return yaml.load(item) elif isinstance(item, dict): return item else: raise ValueError("Got {}, expected string or dict", type(item)) def deep_merge_dict(dict_x, dict_y, path=None): """ Recursively merges dict_y into dict_x. Args: dict_x: A dict. dict_y: A dict. path: Returns: An updated dict of dict_x """ if path is None: path = [] for key in dict_y: if key in dict_x: if isinstance(dict_x[key], dict) and isinstance(dict_y[key], dict): deep_merge_dict(dict_x[key], dict_y[key], path + [str(key)]) elif dict_x[key] == dict_y[key]: pass # same leaf value else: dict_x[key] = dict_y[key] else: dict_x[key] = dict_y[key] return dict_x def parse_params(params, default_params): """Parses parameter values to the types defined by the default parameters. Default parameters are used for missing values. Args: params: A dict. default_params: A dict to provide parameter structure and missing values. Returns: A updated dict. """ # Cast parameters to correct types if params is None: params = {} result = copy.deepcopy(default_params) for key, value in params.items(): # If param is unknown, drop it to stay compatible with past versions if key not in default_params: raise ValueError("{} is not a valid model parameter".format(key)) # Param is a dictionary if isinstance(value, dict): default_dict = default_params[key] if not isinstance(default_dict, dict): raise ValueError("{} should not be a dictionary".format(key)) if default_dict: value = parse_params(value, default_dict) else: # If the default is an empty dict we do not typecheck it # and assume it's done downstream pass if value is None: continue if default_params[key] is None: result[key] = value else: result[key] = type(default_params[key])(value) return result def print_params(title, params): """ Prints parameters. Args: title: A string. params: A dict. """ tf.logging.info(title) for info in _params_to_stringlist(params): tf.logging.info(info) def update_infer_params( model_configs, beam_size=None, maximum_labels_length=None, length_penalty=None): """ Resets inference-specific parameters. Args: model_configs: A dictionary of all model configurations. beam_size: The beam width, if provided, pass it to `model_configs`'s "model_params". maximum_labels_length: The maximum length of sequence that model generates, if provided, pass it to `model_configs`'s "model_params". length_penalty: The length penalty, if provided, pass it to `model_configs`'s "model_params". Returns: An updated dict. """ if beam_size is not None: model_configs["model_params"]["inference.beam_size"] = beam_size if maximum_labels_length is not None: model_configs["model_params"]["inference.maximum_labels_length"] = maximum_labels_length if length_penalty is not None: model_configs["model_params"]["inference.length_penalty"] = length_penalty return model_configs def update_eval_metric( model_configs, metric): """ Resets evaluation-specific parameters. Args: model_configs: A dictionary of all model configurations. metric: A string. Returns: A tuple `(updated_dict, metric_str)`. """ if "modality.target.params" in model_configs["model_params"]: metric_str = model_configs["model_params"]["modality.target.params"]["loss"] else: metric_str = model_configs["model_params"]["modality.params"]["loss"] if metric is not None: metric_str = metric if "modality.target.params" in model_configs["model_params"]: model_configs["model_params"]["modality.target.params"]["loss"] = metric_str else: model_configs["model_params"]["modality.params"]["loss"] = metric_str return model_configs, metric_str @six.add_metaclass(abc.ABCMeta) class Configurable(object): """ Interface for all classes that are configurable via a parameters dictionary. """ def __init__(self, params, mode, name=None, verbose=True): """ Initializes a class. Args: params: A dict of parameters. mode: A mode. name: The name of the object. verbose: Print object parameters if set True. """ self._params = parse_params(params, self.default_params()) self._params = _toggle_dropout(self.params, mode) self._verbose = verbose self._name = name self._mode = mode if verbose: print_params("Parameters for {} under mode={}:" .format(self.__class__.__name__, self.mode), self.params) self._check_parameters() def _check_parameters(self): """ Checks availability of parameters. """ pass @property def name(self): """ Returns the name of the object. """ return self._name @name.setter def name(self, val): """ Set the name. """ self._name = val @property def verbose(self): """ Returns the verbose property. """ return self._verbose @verbose.setter def verbose(self, val): """ Set the verbose property. """ self._verbose = val @property def mode(self): """Returns the mode. """ return self._mode @property def params(self): """Returns a dictionary of parsed parameters. """ return self._params @abstractstaticmethod def default_params(): """Returns a dictionary of default parameters. The default parameters are used to define the expected type of passed parameters. Missing parameter values are replaced with the defaults returned by this method. """ raise NotImplementedError class ModelConfigs: """ A class for dumping and loading model configurations. """ @staticmethod def dump(model_config, output_dir): """ Dumps model configurations. Args: model_config: A dict. output_dir: A string, the output directory. """ model_config_filename = os.path.join(output_dir, Constants.MODEL_CONFIG_YAML_FILENAME) if not gfile.Exists(output_dir): gfile.MakeDirs(output_dir) with open_file(model_config_filename, mode="w") as file: yaml.dump(model_config, file) @staticmethod def load(model_dir): """ Loads model configurations. Args: model_dir: A string, the directory. Returns: A dict. """ model_config_filename = os.path.join(model_dir, Constants.MODEL_CONFIG_YAML_FILENAME) if not gfile.Exists(model_config_filename): raise OSError("Fail to find model config file: %s" % model_config_filename) with open_file(model_config_filename, mode="r") as file: model_configs = yaml.load(file) return model_configs