import multiprocessing as mp
import argparse
import os
import yaml
from packaging import version

from utils import dist_init
from trainer import Trainer


def main(args):
    """Load the YAML experiment config, initialise distributed mode and run training.

    Every top-level key of the YAML file is copied onto ``args`` with
    ``setattr``, so config values override command-line values of the same name.

    Args:
        args: Parsed command-line namespace. Must provide ``config`` (path to
            the YAML file) and ``launcher`` (forwarded to ``dist_init``). It is
            mutated in place and then handed to ``Trainer``.
    """
    with open(args.config) as f:
        # ``yaml.FullLoader`` exists only from PyYAML 5.1 onwards; older
        # releases have no ``Loader`` requirement, so fall back to the plain call.
        # Both sides must be parsed: comparing raw version strings is unreliable.
        if version.parse(yaml.__version__) >= version.parse("5.1"):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            config = yaml.load(f)

    # An empty YAML document loads as ``None``; treat it as "no overrides".
    for k, v in (config or {}).items():
        setattr(args, k, v)

    # exp path: default to the directory containing the config file unless
    # the config itself supplied ``exp_path``.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # dist init: force the 'spawn' start method for child processes.
    # NOTE(review): presumably needed because CUDA/NCCL state cannot be
    # safely inherited through fork — confirm against the trainer's workers.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    # train
    trainer = Trainer(args)
    trainer.run()


if __name__ == '__main__':
    # Command-line interface; most experiment settings come from the YAML
    # file passed via --config and are merged onto the namespace by main().
    cli = argparse.ArgumentParser(description='PyTorch Kinematics')
    cli.add_argument('--config', type=str, required=True)
    cli.add_argument('--launcher', type=str, default='pytorch')
    cli.add_argument('--load-iter', type=int, default=None)
    cli.add_argument('--load-path', type=str, default=None)
    # Boolean switches that default to False and flip to True when present.
    for switch in ('--resume', '--validate', '--extract'):
        cli.add_argument(switch, action='store_true')
    # Rank argument injected by distributed launchers.
    cli.add_argument('--local_rank', default=0, type=int)

    main(cli.parse_args())