import builtins
import re
from copy import deepcopy
from io import IOBase
from typing import Union, List, Optional

import vergeml.glossary as glossary
from vergeml.plugins import PLUGINS
from vergeml.utils import VergeMLError, did_you_mean

# types

class TrainedModel: # pylint: disable=R0903
    """A type representing a trained instance of a model."""


class File: # pylint: disable=R0903
    """A type representing a file that can be read"""


# Option names that ordinary option definitions may not use (see option()).
_RESERVED_OPTION_NAMES = {
    'version',
    'file',
    'model',
    'samples-dir',
    'val-split',
    'test-split',
    'cache-dir',
    'random-seed',
    'trainings-dir',
    'project-dir',
    'cache',
    'device',
    'device-memory'
}

# Short flags that ordinary option definitions may not use (see option()).
_RESERVED_SHORT_OPTION_NAMES = {'v', 'f', 'm'}


def option(name, default=None, descr=None, type=None, validate=None, transform=None,
           long_descr=None, alias=None, short=None, flag=False, yaml_only=False,
           subcommand=False, command_line=False):
    """Defines an option.

    Returns a decorator which appends an Option built from the given
    arguments to the list stored on the decorated object under
    _OPTIONS_META_KEY, and then returns the decorated object unchanged.

    :param name:        Name of the option.
    :param default:     Default value of the option.
    :param type:        Type of the option. Can be either a python type or a string
                        representing the type. Supported types are:
                        - int
                        - float
                        - str
                        - bool
                        - NoneType
                        - dict
                        - list
                        - TrainedModel (shortcut '@')
                        - file
                        - List[<int, float, str>]
                        - Optional[<any of above>]
                        - Union[<any of the above>]
    :param validate:    How to validate the option. Can take the following values:
                        - a string expression using >, <, >= or <=, e.g. '>=0'
                        - a list of possible values
                        - a function which accepts a option definition and a value as options.
                        - None
    :param transform:   Defines a transformation to apply before casting.
    :param descr:       A short description of the option.
    :param long_descr:  A long description.
    :param flag:        For boolean options only: use short form --x for True
    :param short:       An optional short flag, like -oadam
    :param yaml_only:   When true, can only be set via yaml file.
    :param command_line: When true, this argument is only relevant when running on the
                        command line
    :param subcommand:  If true, the option is a subcommand and is parsed as
                        command:subcommand

    The decorator asserts that name and short are not reserved (objects
    named 'ValidateDevice' or 'ValidateData' are exempt) and that no command
    has been attached to the decorated object yet.
    """
    def decorator(o):
        # NOTE: these checks are asserts and are therefore skipped under `python -O`.
        if o.__name__ not in ('ValidateDevice', 'ValidateData'):
            assert name not in _RESERVED_OPTION_NAMES, \
                "Invalid option name: {} - name is reserved.".format(name)
            assert short not in _RESERVED_SHORT_OPTION_NAMES, \
                "Invalid short option name {} for option: {} - name is reserved.".format(short, name)

        assert getattr(o, _CMD_META_KEY, None) is None, _DECORATORS_WRONG_ORDER

        if not hasattr(o, _OPTIONS_META_KEY):
            setattr(o, _OPTIONS_META_KEY, [])
        options = getattr(o, _OPTIONS_META_KEY)

        opt = Option(name=name,
                     default=default,
                     type=type,
                     validate=validate,
                     transform=transform,
                     descr=descr,
                     long_descr=long_descr,
                     alias=alias,
                     short=short,
                     flag=flag,
                     yaml_only=yaml_only,
                     command_line=command_line,
                     subcommand=subcommand)
        options.append(opt)
        return o

    return decorator


_OPTIONS_META_KEY = '__vergeml_options__'
_CMD_META_KEY = '__vergeml_command__'

_DECORATORS_WRONG_ORDER = """
You must first define the command and then the parameters, for example:

@command('train')
@param('learning-rate')
def train():
    ...
""".strip()

# One comparison operator followed by a non-negative decimal number, e.g. ">=0.5".
_VALIDATE_REGEX = r"^(<|>|>=|<=)([0-9][0-9]*(\.[0-9]*)?)$"


class Option:
    """Definition of a single option: name, type, default, validation and metadata."""

    def __init__(self, name, default=None, type=None, validate=None, transform=None,
                 descr=None, long_descr=None, alias=None, short=None, flag=False,
                 yaml_only=False, command_line=False, subcommand=False, plugins=PLUGINS):
        """Defines an option.

        See the documentation of the function option.

        A string `validate` is split on commas and every part is asserted to
        match _VALIDATE_REGEX. A string `type` is turned into a type object
        with eval (after replacing '@' with 'TrainedModel'). When no type is
        given but a default is, the type of the default is used.
        """
        if isinstance(validate, str):
            for val in validate.split(","):
                val = val.strip()
                assert re.match(_VALIDATE_REGEX, val)

        self.name = name
        self.default = default

        # as a shortcut, type can be a string.
        if isinstance(type, str):
            # @ is a shortcut for a trained model
            type = type.replace("@", "TrainedModel")
            # convert the string to an object representing the type
            # NOTE(review): eval of the type string - must only ever receive
            # strings written in option definitions, never external input.
            type = eval(type) # pylint: disable=W0123

        self.type = type

        # `type` is shadowed by the parameter, hence builtins.type
        if not self.type and default is not None:
            self.type = builtins.type(default)

        self.validate = validate
        self.transform = transform
        self.descr = descr or glossary.short_param_descr(name)
        self.long_descr = long_descr or glossary.long_descr(name)
        self.alias = alias
        self.plugins = plugins
        self.short = short
        self.flag = flag
        self.yaml_only = yaml_only
        self.command_line = command_line
        self.subcommand = subcommand

    @staticmethod
    def discover(o, plugins=PLUGINS):
        """Return deep copies of the options attached to `o`.

        Each copy gets its `plugins` attribute set to `plugins`. Returns an
        empty list when `o` carries no options.
        """
        res = []
        if hasattr(o, _OPTIONS_META_KEY):
            res = deepcopy(getattr(o, _OPTIONS_META_KEY))
            for r in res:
                r.plugins = plugins
        return res

    def _invalid_value(self, value, suggestion=None):
        """Build (not raise) the VergeMLError reporting an invalid value for this option."""
        return VergeMLError(f"Invalid value for option {self.name}.", suggestion,
                            hint_type='value', hint_key=self.name)

    def has_type(self, *types):
        """Check if the option is of a type in the list types"""
        return _has_type(self.type, *types)

    def validate_value(self, value):
        """Validate `value` against the option's `validate` setting.

        Does nothing when no validation is configured, or when the option is
        not required and the value is None or a spelling of 'null'.

        - list/tuple: the value must be one of the entries.
        - callable:   called as validate(option, value).
        - string:     comma separated comparisons (e.g. '>=0,<1') which the
                      value, converted to float, must satisfy.

        :raises VergeMLError: when the value is not valid.
        """
        if not self.validate:
            return

        if not self.is_required() and value in (None, 'null', 'Null', 'NULL'):
            return

        if isinstance(self.validate, (tuple, list)) and value not in self.validate:
            suggestion = None
            if all(map(lambda e: isinstance(e, str), self.validate)):
                suggestion = did_you_mean(self.validate, value)
            raise self._invalid_value(value, suggestion)
        elif callable(self.validate):
            self.validate(self, value)
        elif isinstance(self.validate, str):
            # Convert once, up front. TypeError covers values float() rejects
            # outright (e.g. a list), which previously escaped uncaught.
            try:
                number = float(value)
            except (TypeError, ValueError):
                raise self._invalid_value(value)

            for constraint in self.validate.split(","):
                op, num_str = re.match(_VALIDATE_REGEX, constraint.strip()).group(1, 2)
                bound = float(num_str)

                if op == '>':
                    if not number > bound:
                        raise self._invalid_value(number, f"Must be greater than {num_str}")
                elif op == '<':
                    if not number < bound:
                        raise self._invalid_value(number, f"Must be less than {num_str}")
                elif op == '>=':
                    if not number >= bound:
                        raise self._invalid_value(number,
                                                  f"Must be greater or equal to {num_str}")
                elif op == '<=':
                    if not number <= bound:
                        raise self._invalid_value(number,
                                                  f"Must be less than or equal to {num_str}")

    def cast_value(self, value, type_=None): # pylint: disable=R0915,R0911,R0912
        """Cast value to the type of option.

        :param value: The value to cast.
        :param type_: Type to cast to; defaults to the option's own type.
                      When neither is set, the value is returned unchanged.
        :return: The cast value.
        :raises ValueError:   when the value can't be cast to a trained model
                              or file type.
        :raises VergeMLError: when the value can't be cast to any other
                              supported type.
        """
        type_ = type_ or self.type

        if not type_:
            return value

        # --- trained models are represented as strings
        if _has_type(type_, '@', 'Optional[@]') and isinstance(value, (str, int, float, bool)):
            return str(value)

        if _has_type(type_, 'Optional[@]') and value is None:
            return value

        if _has_type(type_, 'List[@]'):
            if isinstance(value, list) and all(map(lambda e: isinstance(e, str), value)):
                return value
            raise ValueError("Could not cast to trained model.")

        if _has_type(type_, '@', 'Optional[@]'):
            raise ValueError("Could not cast to trained model.")

        # --- files are represented as strings
        if _has_type(type_, 'File'):
            if isinstance(value, str):
                return value
            raise ValueError("Could not cast to file.")

        if _has_type(type_, 'Optional[File]'):
            if isinstance(value, str) or value is None:
                return value
            raise ValueError("Could not cast to file")

        if _has_type(type_, 'List[File]'):
            if isinstance(value, list) and all(map(lambda e: isinstance(e, str), value)):
                return value
            raise ValueError("Could not cast to file")

        # --- scalars; bool is rejected explicitly because it subclasses int
        is_plain_scalar = isinstance(value, (int, float, str)) and not isinstance(value, bool)

        if _has_type(type_, int):
            if not is_plain_scalar:
                raise self._invalid_value(value)
            try:
                return int(value)
            except (ValueError, OverflowError):
                # OverflowError: int(float('inf'))
                raise self._invalid_value(value)

        if _has_type(type_, float):
            if not is_plain_scalar:
                raise self._invalid_value(value)
            try:
                return float(value)
            except ValueError:
                raise self._invalid_value(value)

        if _has_type(type_, str):
            if not is_plain_scalar:
                raise self._invalid_value(value)
            return str(value)

        if _has_type(type_, bool):
            if isinstance(value, bool):
                return value
            if value in ('y', 'Y', 'yes', 'Yes', 'YES', 'on', 'On', 'ON',
                         'true', 'True', 'TRUE'):
                return True
            if value in ('n', 'N', 'no', 'No', 'NO', 'off', 'Off', 'OFF',
                         'false', 'False', 'FALSE'):
                return False
            raise self._invalid_value(value)

        if type_ == type(None):
            if isinstance(value, type(None)):
                return value
            if isinstance(value, str) and value in ('null', 'Null', 'NULL'):
                return None
            raise self._invalid_value(value)

        if _has_type(type_, dict):
            if not isinstance(value, dict):
                raise self._invalid_value(value)
            return value

        if _has_type(type_, list):
            if not isinstance(value, list):
                raise self._invalid_value(value)
            # parametrized list: cast every item to the item type
            if hasattr(type_, '__origin__'):
                return [self.cast_value(i, type_.__args__[0]) for i in value]
            return value

        if getattr(type_, '__origin__', None) == Union:
            # Try the member types in order and keep the first successful cast.
            for typ in type_.__args__:
                try:
                    return self.cast_value(value, typ)
                except (VergeMLError, ValueError):
                    # ValueError is what the trained model / file casts raise;
                    # without it a union containing them aborts instead of
                    # moving on to the next member type.
                    continue
            raise self._invalid_value(value)

        # Unknown type
        assert False
        return None

    def transform_value(self, value):
        """Apply the option's transform function to value, if one is set."""
        if self.transform:
            return self.transform(value)
        return value

    def has_optional_type(self):
        """Check if the type of the option permits None values.
        """
        return getattr(self.type, '__origin__', None) == Union and \
               type(None) in self.type.__args__ # pylint: disable=C0123

    def is_optional(self):
        """True when the option has a non-None default or a type permitting None."""
        return self.default is not None or self.has_optional_type()

    def is_required(self):
        """True when the option is not optional."""
        return not self.is_optional()

    def is_at_option(self):
        """True when the option name starts with '@'."""
        return self.name.startswith("@")

    def is_argument_option(self):
        """True when the option name is enclosed in angle brackets, e.g. '<file>'."""
        return self.name.startswith("<") and self.name.endswith(">")

    def _type_descr(self, tp):
        """Return a human readable description of the type `tp` ('' when tp is falsy)."""
        tp_descr = ""
        if tp == TrainedModel:
            tp_descr = "trained model"
        elif tp == Optional[TrainedModel]:
            tp_descr = "optional trained model"
        elif tp == File:
            tp_descr = "file"
        elif tp == Optional[File]:
            tp_descr = "optional file"
        elif tp == List[File]:
            tp_descr = "a list of files"
        elif hasattr(tp, '__origin__'):
            if tp.__origin__ in (list, List):
                tp_descr = "a list of " + self._type_descr(tp.__args__[0])
            elif tp.__origin__ == Union:
                none_type = type(None)
                if len(tp.__args__) == 2 and none_type in tp.__args__:
                    # The members are type objects, so NoneType has to be
                    # found by identity; an isinstance() test never matches
                    # and picked NoneType itself for Union[None, X].
                    other = next(t for t in tp.__args__ if t is not none_type)
                    tp_descr = 'optional ' + self._type_descr(other)
                else:
                    names = [self._type_descr(t) for t in tp.__args__]
                    if len(names) <= 2:
                        tp_descr = " or ".join(names)
                    else:
                        tp_descr = ", ".join(names[:-1])
                        tp_descr += " or " + names[-1]
        elif tp:
            if tp == str:
                tp_descr = "string"
            else:
                tp_descr = getattr(tp, '__name__', str(tp))
        return tp_descr

    def human_type(self):
        """Return a description of the option's type, validation and default for display."""
        tp_descr = self._type_descr(self.type)
        if tp_descr and isinstance(self.validate, str):
            tp_descr += " " + self.validate
        elif tp_descr and isinstance(self.validate, (tuple, list)):
            tp_descr = "one of (" + ", ".join(map(str, self.validate)) + ")"

        if self.default:
            tp_descr += ", default: " + str(self.default)

        return tp_descr


def _has_type(type_, *types):
    """Return True when `type_` matches one of `types`.

    Entries of `types` may be type objects or strings; strings are evaluated
    to a type after replacing '@' with 'TrainedModel'. `list`/`List` also
    match any `type_` whose `__name__` is 'List'.
    """
    for typ in types:
        if isinstance(typ, str):
            typ = typ.replace('@', 'TrainedModel')
            typ = eval(typ) # pylint: disable=W0123
        if type_ == typ:
            return True
        if typ in (list, List) and getattr(type_, '__name__', None) == 'List':
            return True
    return False