# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import typing as tp
from pathlib import Path
from concurrent import futures
from nevergrad.common.typetools import PathLike
from . import utils
from . import core
from . import plotting


# pylint: disable=too-many-arguments
def launch(
    experiment: str,
    num_workers: int = 1,
    seed: tp.Optional[int] = None,
    cap_index: tp.Optional[int] = None,
    output: tp.Optional[PathLike] = None,
) -> Path:
    """Run one experiment plan and save (or append) its results to a CSV file.

    Parameters
    ----------
    experiment: str
        name of the experiment, forwarded to ``core.compute``
    num_workers: int
        with 1 worker the computation runs in the current process, otherwise
        a ``ProcessPoolExecutor`` with ``num_workers`` workers is handed to
        ``core.compute``
    seed: optional int
        seed forwarded to ``core.compute``
    cap_index: optional int
        forwarded to ``core.compute`` to limit the number of settings
    output: optional path
        path of the CSV file; defaults to ``<experiment>.csv``

    Returns
    -------
    Path
        the path of the CSV file to which the data was actually written

    Note
    ----
    If writing to ``output`` fails, the data is written to
    ``<experiment>.csv`` instead. If that default path is itself the one
    which failed, the original error is re-raised.
    """
    # create the data
    default_csvpath = Path(experiment + ".csv")
    csvpath = default_csvpath if output is None else Path(output)
    if num_workers == 1:
        df = core.compute(experiment, cap_index=cap_index, seed=seed)
    else:
        with futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
            df = core.compute(
                experiment, seed=seed, cap_index=cap_index, executor=executor, num_workers=num_workers
            )
    # save data to csv
    try:
        core.save_or_append_to_csv(df, csvpath)
    except Exception:  # pylint: disable=broad-except
        if csvpath == default_csvpath:
            # the fallback is the very file which just failed: retrying it would
            # only hide the actual error behind a misleading message
            raise
        csvpath = default_csvpath
        print(f"Failed to save to {output}, falling back to {csvpath}")
        core.save_or_append_to_csv(df, csvpath)
    else:
        print(f"Saved data to {csvpath}")
    return csvpath


def get_args() -> argparse.Namespace:
    """Parse the command line arguments of the script.

    Returns
    -------
    argparse.Namespace
        namespace with attributes experiment, seed, cap_index, output,
        imports, num_workers, repetitions and plot
    """
    parser = argparse.ArgumentParser(description='Run an experiment and create a result csv file.')
    parser.add_argument('experiment', type=str,
                        help='name of an experiment registered in the experiments registry')
    parser.add_argument('--seed', type=int, default=None,
                        help="Use a seed for reproducibility (for generators which take care of seeding)")
    parser.add_argument('--cap_index', type=int, default=None,
                        help="Stop after generationg/running settings #cap_index")
    parser.add_argument('--output', type=str, default=None,
                        help="Output path for the CSV file (default: <experiment>.csv). Existing files are appended")
    parser.add_argument('--imports', type=str, default=None,
                        help="Comma-separated list of file paths with additional experiment(s) and/or optimizer(s) definitions")
    parser.add_argument('--num_workers', type=int, default=1,
                        help="Numbers of workers to use for the computation (splits the job in chunks)")
    parser.add_argument('--repetitions', type=int, default=1,
                        help="Number of repetitions to perform for the experiment plan (seeds will be incremented)")
    # --plot alone yields True (const), --plot <path> yields the path, and no flag yields False
    parser.add_argument('--plot', nargs="?", default=False, const=True,
                        help="Creates the corresponding plots if present (provide a path, or folder <experiment>_plots will be used)")
    return parser.parse_args()


def repeated_launch(
    experiment: str,
    num_workers: int = 1,
    seed: tp.Optional[int] = None,
    cap_index: tp.Optional[int] = None,
    output: tp.Optional[PathLike] = None,
    plot: tp.Union[bool, PathLike] = False,
    imports: tp.Optional[tp.List[PathLike]] = None,
    repetitions: int = 1,
) -> None:
    """Run an experiment plan several times, incrementing the seed each time.

    Parameters
    ----------
    experiment: str
        name of the experiment, forwarded to ``launch``
    num_workers: int
        number of workers, forwarded to ``launch``
    seed: optional int
        seed of the first repetition; repetition ``k`` (0-based) uses
        ``seed + k``. With None, every repetition is launched unseeded.
    cap_index: optional int
        forwarded to ``launch`` to limit the number of settings
    output: optional path
        path of the CSV file, forwarded to ``launch``
    plot: bool or path
        if truthy, plots are created from the CSV file of the last repetition.
        With True, the folder is the CSV path without its suffix, followed by
        "_plots"; otherwise the provided value is used as output folder.
    imports: optional list of paths
        each path is passed to ``core.import_additional_module`` before the
        runs start
    repetitions: int
        number of times the plan is run
    """
    # start by importing additional content
    if imports is not None:
        assert isinstance(imports, (tuple, list))
        for path in imports:
            core.import_additional_module(path)
    # then run multiple times
    csvpath: tp.Optional[Path] = None
    for k in range(repetitions):
        print(f"Starting repetition {k + 1} / {repetitions}")
        csvpath = launch(
            experiment,
            num_workers=num_workers,
            cap_index=cap_index,
            output=output,
            seed=None if seed is None else seed + k,
        )
    # save plots if need be
    if not plot:
        return
    if csvpath is None:
        # no repetition was run (repetitions < 1), so there is no CSV file
        # produced by this call to read the data from
        print("No repetition was run, nothing to plot")
        return
    df = utils.Selector.read_csv(csvpath)
    if isinstance(plot, bool):
        plot = str(Path(csvpath).with_suffix("")) + "_plots"
    print(f"Saving plots into folder: {plot}")
    plotting.create_plots(df, output_folder=plot)


if __name__ == "__main__":
    args = get_args()
    repeated_launch(
        args.experiment,
        num_workers=args.num_workers,
        cap_index=args.cap_index,
        output=args.output,
        seed=args.seed,
        plot=args.plot,
        # --imports is a single comma-separated string on the command line
        imports=None if args.imports is None else args.imports.split(","),
        repetitions=args.repetitions,
    )