import argparse
from test import test
from environment import Environment


def parse(argv=None):
    """Build the command-line parser for the HW3 agents and parse arguments.

    If a local ``argument`` module providing ``add_arguments(parser)`` is
    importable, it is given the parser to register extra, agent-specific
    options; its return value is used as the parser from then on.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behavior.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="MLDS&ADL HW3")
    parser.add_argument('--env_name', default=None, help='environment name')
    parser.add_argument('--train_pg', action='store_true', help='whether to train policy gradient')
    parser.add_argument('--test_pg', action='store_true', help='whether to test policy gradient')
    parser.add_argument('--train_ac', action='store_true', help='whether to train Actor Critic')
    parser.add_argument('--train_pgc', action='store_true', help='whether to train PG on CartPole')
    parser.add_argument('--video_dir', default=None, help='output video directory')
    parser.add_argument('--do_render', action='store_true', help='whether to render the environment')
    parser.add_argument('--save_summary_path', type=str, default="pg_summary/",
                        help='output path for summaries')
    parser.add_argument('--save_network_path', type=str, default="saved_pg_networks/",
                        help='output path for saved networks')
    try:
        from argument import add_arguments
    except ImportError:
        # The extra-arguments module is optional; only a missing module/name
        # is ignored. Errors raised while registering its options propagate
        # instead of silently leaving a partially configured parser.
        pass
    else:
        parser = add_arguments(parser)
    return parser.parse_args(argv)


def run(args):
    """Execute every mode enabled in ``args``, in a fixed order.

    The flags are checked independently: train_pg, test_pg, train_ac, then
    train_pgc. Several modes can therefore run in one invocation (for
    example ``--train_pg --test_pg``). Agent classes are imported inside
    each branch, so only the modules for the selected modes are loaded.

    Args:
        args: Namespace from :func:`parse`. Reads ``env_name``, ``train_pg``,
            ``test_pg``, ``train_ac`` and ``train_pgc``. It is also passed
            through to ``Environment`` and to each agent constructor.
    """
    if args.train_pg:
        env = Environment(args.env_name or 'Pong-v0', args)
        from agent_dir.agent_pg import Agent_PG
        agent = Agent_PG(env, args)
        agent.train()

    if args.test_pg:
        # Honor --env_name like every other mode does. Otherwise a model
        # trained on a non-default environment would be tested on Pong.
        env = Environment(args.env_name or 'Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        agent = Agent_PG(env, args)
        test(agent, env)

    # The CartPole agents below are training-only; there is no test mode for them.
    if args.train_ac:
        env = Environment(args.env_name or 'CartPole-v0', args)
        from agent_dir.agent_actorcritic import Agent_ActorCritic
        agent = Agent_ActorCritic(env, args)
        agent.train()

    if args.train_pgc:
        env = Environment(args.env_name or 'CartPole-v0', args)
        from agent_dir.agent_pg_cart import Agent_PGC
        agent = Agent_PGC(env, args)
        agent.train()


if __name__ == '__main__':
    # Script entry point: read the command-line flags, then run every mode
    # they enable. The distinct name avoids shadowing run()'s parameter.
    cli_args = parse()
    run(cli_args)