"""
Manage AWS Batch jobs, queues, and compute environments.
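
Example (illustrative; all flags referenced here are defined in this module):

    aegea batch submit --command "echo hello" --watch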
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import sys, argparse, base64, json, time, re, hashlib, itertools

from botocore.exceptions import ClientError

from . import config, logger
from .ls import register_parser, register_listing_parser
from .ecr import ecr_image_name_completer
from .ssh import ssh as aegea_ssh, ssh_parser as aegea_ssh_parser
from .util import Timestamp, paginate, get_mkfs_command, ThreadPoolExecutor
from .util.crypto import ensure_ssh_key
from .util.cloudinit import get_user_data
from .util.exceptions import AegeaException
from .util.printing import page_output, tabulate, YELLOW, RED, GREEN, BOLD, ENDC
from .util.aws import (resources, clients, ensure_iam_role, ensure_instance_profile, make_waiter, ensure_vpc,
                       ensure_security_group, ensure_log_group, IAMPolicyBuilder, resolve_ami, instance_type_completer,
                       expect_error_codes, instance_storage_shellcode, ARN, get_ssm_parameter)
from .util.aws.spot import SpotFleetBuilder
from .util.aws.logs import CloudwatchLogReader
from .util.aws.batch import ensure_job_definition, get_command_and_env, ensure_lambda_helper

def complete_queue_name(**kwargs):
    return [q["jobQueueName"] for q in paginate(clients.batch.get_paginator("describe_job_queues"))]

def complete_ce_name(**kwargs):
    return [c["computeEnvironmentName"] for c in paginate(clients.batch.get_paginator("describe_compute_environments"))]

def batch(args):
    batch_parser.print_help()

batch_parser = register_parser(batch, help="Manage AWS Batch resources", description=__doc__)

def queues(args):
    page_output(tabulate(paginate(clients.batch.get_paginator("describe_job_queues")), args))

parser = register_listing_parser(queues, parent=batch_parser, help="List Batch queues")

def create_queue(args):
    ces = [dict(computeEnvironment=e, order=i) for i, e in enumerate(args.compute_environments)]
    logger.info("Creating queue %s in %s", args.name, ces)
    queue = clients.batch.create_job_queue(jobQueueName=args.name, priority=args.priority, computeEnvironmentOrder=ces)
    make_waiter(clients.batch.describe_job_queues, "jobQueues[].status", "VALID", "pathAny").wait(jobQueues=[args.name])
    return queue

parser = register_parser(create_queue, parent=batch_parser, help="Create a Batch queue")
parser.add_argument("name")
parser.add_argument("--priority", type=int, default=5)
parser.add_argument("--compute-environments", nargs="+", required=True)

def delete_queue(args):
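    # A job queue must be disabled and settle back to VALID before Batch will accept the delete call.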
    clients.batch.update_job_queue(jobQueue=args.name, state="DISABLED")
    make_waiter(clients.batch.describe_job_queues, "jobQueues[].status", "VALID", "pathAny").wait(jobQueues=[args.name])
    clients.batch.delete_job_queue(jobQueue=args.name)

parser = register_parser(delete_queue, parent=batch_parser, help="Delete a Batch queue")
parser.add_argument("name").completer = complete_queue_name

def compute_environments(args):
    page_output(tabulate(paginate(clients.batch.get_paginator("describe_compute_environments")), args))

parser = register_listing_parser(compute_environments, parent=batch_parser, help="List Batch compute environments")

def ensure_launch_template(prefix=__name__.replace(".", "_"), **kwargs):
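    # Idempotent: the template name embeds a SHA-256 digest of its parameters, so identical
    # configurations map to the same launch template and re-creation attempts are tolerated.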
    name = prefix + "_" + hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()[:32]
    try:
        clients.ec2.create_launch_template(LaunchTemplateName=name, LaunchTemplateData=kwargs)
    except ClientError as e:
        expect_error_codes(e, "InvalidLaunchTemplateName.AlreadyExistsException")
    return name

def create_compute_environment(args):
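    # Build cloud-init user data that formats and mounts ephemeral instance storage, resolve an
    # ECS-optimized AMI, ensure the supporting IAM roles, VPC, SSH key, and security group, then
    # create the compute environment and wait for it to become VALID.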
    commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n")
    user_data = get_user_data(commands=commands, mime_multipart_archive=True)
    if args.ecs_container_instance_ami:
        ecs_ami_id = args.ecs_container_instance_ami
    elif args.ecs_container_instance_ami_tags:
        # NAME=VALUE tag pairs are collected into a dict and passed to resolve_ami as tag filters
        ecs_ami_id = resolve_ami(**dict(args.ecs_container_instance_ami_tags))
    else:
        ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id")
    launch_template = ensure_launch_template(ImageId=ecs_ami_id,
                                             # TODO: add configurable BDM for Docker image cache space
                                             UserData=base64.b64encode(user_data).decode())
    batch_iam_role = ensure_iam_role(args.service_role, trust=["batch"], policies=["service-role/AWSBatchServiceRole"])
    vpc = ensure_vpc()
    ssh_key_name = ensure_ssh_key(args.ssh_key_name, base_name=__name__)
    instance_profile = ensure_instance_profile(args.instance_role,
                                               policies={"service-role/AmazonAPIGatewayPushToCloudWatchLogs",
                                                         "service-role/AmazonEC2ContainerServiceforEC2Role",
                                                         "AmazonSSMManagedInstanceCore",
                                                         IAMPolicyBuilder(action="sts:AssumeRole", resource="*")})
    compute_resources = dict(type=args.compute_type,
                             minvCpus=args.min_vcpus, desiredvCpus=args.desired_vcpus, maxvCpus=args.max_vcpus,
                             instanceTypes=args.instance_types,
                             subnets=[subnet.id for subnet in vpc.subnets.all()],
                             securityGroupIds=[ensure_security_group("aegea.launch", vpc).id],
                             instanceRole=instance_profile.name,
                             bidPercentage=100,
                             spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name,
                             ec2KeyPair=ssh_key_name,
                             tags=dict(Name=__name__),
                             launchTemplate=dict(launchTemplateName=launch_template))
    logger.info("Creating compute environment %s in %s", args.name, vpc)
    compute_environment = clients.batch.create_compute_environment(computeEnvironmentName=args.name,
                                                                   type=args.type,
                                                                   computeResources=compute_resources,
                                                                   serviceRole=batch_iam_role.name)
    wtr = make_waiter(clients.batch.describe_compute_environments, "computeEnvironments[].status", "VALID", "pathAny",
                      delay=2, max_attempts=300)
    wtr.wait(computeEnvironments=[args.name])
    return compute_environment

cce_parser = register_parser(create_compute_environment, parent=batch_parser, help="Create a Batch compute environment")
cce_parser.add_argument("name")
cce_parser.add_argument("--type", choices={"MANAGED", "UNMANAGED"})
cce_parser.add_argument("--compute-type", choices={"EC2", "SPOT"})
cce_parser.add_argument("--min-vcpus", type=int)
cce_parser.add_argument("--desired-vcpus", type=int)
cce_parser.add_argument("--max-vcpus", type=int)
cce_parser.add_argument("--instance-types", nargs="+").completer = instance_type_completer
cce_parser.add_argument("--ssh-key-name")
cce_parser.add_argument("--instance-role", default=__name__ + ".ecs_container_instance")
cce_parser.add_argument("--service-role", default=__name__ + ".service")
cce_parser.add_argument("--ecs-container-instance-ami")
cce_parser.add_argument("--ecs-container-instance-ami-tags")

def update_compute_environment(args):
    update_compute_environment_args = dict(computeEnvironment=args.name, computeResources={})
    if args.min_vcpus is not None:
        update_compute_environment_args["computeResources"].update(minvCpus=args.min_vcpus)
    if args.desired_vcpus is not None:
        update_compute_environment_args["computeResources"].update(desiredvCpus=args.desired_vcpus)
    if args.max_vcpus is not None:
        update_compute_environment_args["computeResources"].update(maxvCpus=args.max_vcpus)
    return clients.batch.update_compute_environment(**update_compute_environment_args)

uce_parser = register_parser(update_compute_environment, parent=batch_parser, help="Update a Batch compute environment")
uce_parser.add_argument("name").completer = complete_ce_name
uce_parser.add_argument("--min-vcpus", type=int)
uce_parser.add_argument("--desired-vcpus", type=int)
uce_parser.add_argument("--max-vcpus", type=int)

def delete_compute_environment(args):
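    # As with job queues, a compute environment must be disabled and return to VALID before deletion.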
    clients.batch.update_compute_environment(computeEnvironment=args.name, state="DISABLED")
    wtr = make_waiter(clients.batch.describe_compute_environments, "computeEnvironments[].status", "VALID", "pathAny")
    wtr.wait(computeEnvironments=[args.name])
    clients.batch.delete_compute_environment(computeEnvironment=args.name)

parser = register_parser(delete_compute_environment, parent=batch_parser, help="Delete a Batch compute environment")
parser.add_argument("name").completer = complete_ce_name

def ensure_queue(name):
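    # Try to create the queue directly; if that fails (typically because the referenced compute
    # environment does not exist yet), create a compute environment of the same name and retry.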
    cq_args = argparse.Namespace(name=name, priority=5, compute_environments=[name])
    try:
        return create_queue(cq_args)
    except ClientError:
        create_compute_environment(cce_parser.parse_args(args=[name]))
        return create_queue(cq_args)

def submit(args):
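    # Overall flow: ensure the log-archiving Lambda helper and log groups exist, build or reuse a
    # job definition, fill in container overrides (command, environment, memory), then submit the
    # job, creating the queue on demand if it does not exist.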
    try:
        ensure_lambda_helper()
    except Exception as e:
        logger.error("Failed to install Lambda helper:")
        logger.error("%s: %s", type(e).__name__, e)
        logger.error("Aegea will be unable to look up logs for old Batch jobs.")
    if args.staging_s3_bucket is None:
        args.staging_s3_bucket = "aegea-batch-jobs-" + ARN.get_account_id()
    if args.job_definition_arn is None:
        if not any([args.command, args.execute, args.wdl]):
            raise AegeaException("One of the arguments --command --execute --wdl is required")
    elif args.name is None:
        raise AegeaException("The argument --name is required")
    ensure_log_group("docker")
    ensure_log_group("syslog")
    if args.job_definition_arn is None:
        command, environment = get_command_and_env(args)
        container_overrides = dict(command=command, environment=environment)

        if args.job_role == config.batch_submit.job_role:
            args.default_job_role_iam_policies.append(IAMPolicyBuilder(
                action=["s3:List*", "s3:HeadObject*", "s3:GetObject*", "s3:PutObject*"],
                resource=["arn:aws:s3:::aegea-*", "arn:aws:s3:::aegea-*/*"]
            ))
        else:
            args.default_job_role_iam_policies = []

        jd_res = ensure_job_definition(args)
        args.job_definition_arn = jd_res["jobDefinitionArn"]
        args.name = args.name or "{}_{}".format(jd_res["jobDefinitionName"], jd_res["revision"])
    else:
        container_overrides = {}
        if args.command:
            container_overrides["command"] = args.command
        if args.environment:
            container_overrides["environment"] = args.environment
    if args.memory is None:
        logger.warn("Specify a memory quota for your job with --memory-mb NNNN.")
        logger.warn("The memory quota is required and a hard limit. Setting it to %d MB.", int(args.default_memory_mb))
        args.memory = int(args.default_memory_mb)
    container_overrides["memory"] = args.memory
    submit_args = dict(jobName=args.name,
                       jobQueue=args.queue,
                       dependsOn=[dict(jobId=dep) for dep in args.depends_on],
                       jobDefinition=args.job_definition_arn,
                       parameters={k: v for k, v in args.parameters},
                       containerOverrides=container_overrides)
    if args.dry_run:
        logger.info("The following command would be run:")
        sys.stderr.write(json.dumps(submit_args, indent=4) + "\n")
        return {"Dry run succeeded": True}
    try:
        job = clients.batch.submit_job(**submit_args)
    except ClientError as e:
        if not re.search("JobQueue .+ not found", str(e)):
            raise
        ensure_queue(args.queue)
        job = clients.batch.submit_job(**submit_args)
    if args.watch:
        try:
            watch(watch_parser.parse_args([job["jobId"]]))
            if args.wdl:
                wdl_output = resources.s3.Bucket(args.staging_s3_bucket).Object("wdl_output/{jobId}.json".format(**job))
                return json.loads(wdl_output.get()["Body"].read())
        except KeyboardInterrupt:
            logger.critical("Interrupt received for Batch job %s.", job["jobId"])
            logger.critical("Press Enter to terminate job. Press Ctrl-C to quit.")
            input()
            clients.batch.terminate_job(jobId=job["jobId"], reason="Terminated by aegea.batch from user interrupt")
            return SystemExit("Sent termination request for Batch job {}".format(job["jobId"]))
    elif args.wait:
        raise NotImplementedError()
    return job

submit_parser = register_parser(submit, parent=batch_parser, help="Submit a job to a Batch queue")
submit_parser.add_argument("--name")
submit_parser.add_argument("--queue", default=__name__.replace(".", "_")).completer = complete_queue_name
submit_parser.add_argument("--depends-on", nargs="+", metavar="JOB_ID", default=[])
submit_parser.add_argument("--job-definition-arn")

def add_command_args(parser):
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--watch", action="store_true", help="Monitor submitted job, stream log until job completes")
    group.add_argument("--wait", action="store_true",
                       help="Block on job. Exit with code 0 if job succeeded, 1 if failed")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--command", nargs="+", help="Run these commands as the job (using " + BOLD("bash -c") + ")")
    group.add_argument("--execute", type=argparse.FileType("rb"), metavar="EXECUTABLE",
                       help="Read this executable file and run it as the job")
    group.add_argument("--wdl", type=argparse.FileType("rb"), metavar="WDL_WORKFLOW",
                       help="Read this WDL workflow file and run it as the job")
    parser.add_argument("--wdl-input", type=argparse.FileType("r"), metavar="WDL_INPUT_JSON", default=sys.stdin,
                        help="With --wdl, use this JSON file as the WDL job input (default: stdin)")
    parser.add_argument("--environment", nargs="+", metavar="NAME=VALUE",
                        type=lambda x: dict(zip(["name", "value"], x.split("=", 1))), default=[])
    parser.add_argument("--staging-s3-bucket", help=argparse.SUPPRESS)

def add_job_defn_args(parser):
    parser.add_argument("--ulimits", nargs="*",
                        help="Separate ulimit name and value with colon, for example: --ulimits nofile:20000",
                        default=["nofile:100000"])
    img_group = parser.add_mutually_exclusive_group()
    img_group.add_argument("--image", default="ubuntu", metavar="DOCKER_IMAGE",
                           help="Docker image URL to use for running job/task")
    ecs_img_help = "Name of Docker image residing in this account's Elastic Container Registry"
    ecs_img_arg = img_group.add_argument("--ecs-image", "--ecr-image", "-i", metavar="REPO[:TAG]", help=ecs_img_help)
    ecs_img_arg.completer = ecr_image_name_completer
    parser.add_argument("--volumes", nargs="+", metavar="HOST_PATH=GUEST_PATH", type=lambda x: x.split("=", 1),
                        default=[])
    parser.add_argument("--memory-mb", dest="memory", type=int)

add_command_args(submit_parser)

group = submit_parser.add_argument_group(title="job definition parameters", description="""
See http://docs.aws.amazon.com/batch/latest/userguide/job_definitions.html""")
add_job_defn_args(group)
group.add_argument("--vcpus", type=int, default=1)
group.add_argument("--gpus", type=int, default=0)
group.add_argument("--privileged", action="store_true", default=False)
group.add_argument("--volume-type", choices={"standard", "io1", "gp2", "sc1", "st1"},
                   help="io1, PIOPS SSD; gp2, general purpose SSD; sc1, cold HDD; st1, throughput optimized HDD")
group.add_argument("--parameters", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), default=[])
group.add_argument("--job-role", metavar="IAM_ROLE", help="Name of IAM role to grant to the job")
group.add_argument("--storage", nargs="+", metavar="MOUNTPOINT=SIZE_GB",
                   type=lambda x: x.rstrip("GBgb").split("=", 1), default=[])
group.add_argument("--efs-storage", action="store", dest="efs_storage", default=False,
                   help="Mount EFS network filesystem to the mount point specified. Example: --efs-storage /mnt")
group.add_argument("--mount-instance-storage", nargs="?", const="/mnt",
                   help="Assemble (MD RAID0), format and mount ephemeral instance storage on this mount point")
submit_parser.add_argument("--timeout",
                           help="Terminate (and possibly restart) the job after this time (use suffix s, m, h, d, w)")
submit_parser.add_argument("--retry-attempts", type=int, default=1,
                           help="Number of times to restart the job upon failure")
submit_parser.add_argument("--dry-run", action="store_true", help="Gather arguments and stop short of submitting job")

def terminate(args):
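    # Terminate the first job synchronously so errors surface immediately, then fan out the rest
    # across a thread pool.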
    def terminate_one(job_id):
        return clients.batch.terminate_job(jobId=job_id, reason=args.reason)

    result = [terminate_one(args.job_id[0])]

    if len(args.job_id) > 1:
        with ThreadPoolExecutor() as executor:
            result += list(executor.map(terminate_one, args.job_id[1:]))
    logger.info("Sent termination requests for %d jobs", len(result))

parser = register_parser(terminate, parent=batch_parser, help="Terminate Batch jobs")
parser.add_argument("job_id", nargs="+")
parser.add_argument("--reason", help="A message to attach to the job that explains the reason for canceling it")

def ls(args, page_size=100):
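    # Collect job IDs for each (queue, status) pair in parallel, then describe them in batches of
    # page_size (the Batch DescribeJobs API accepts at most 100 job IDs per call).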
    queues = args.queues or [q["jobQueueName"] for q in clients.batch.describe_job_queues()["jobQueues"]]

    def list_jobs_worker(list_jobs_worker_args):
        queue, status = list_jobs_worker_args
        return [j["jobId"] for j in clients.batch.list_jobs(jobQueue=queue, jobStatus=status)["jobSummaryList"]]

    with ThreadPoolExecutor() as executor:
        job_ids = sum(executor.map(list_jobs_worker, itertools.product(queues, args.status)), [])

        def describe_jobs_worker(start_index):
            return clients.batch.describe_jobs(jobs=job_ids[start_index:start_index + page_size])["jobs"]

        table = sum(executor.map(describe_jobs_worker, range(0, len(job_ids), page_size)), [])
    page_output(tabulate(table, args, cell_transforms={"createdAt": Timestamp}))

job_status_colors = dict(SUBMITTED=YELLOW(), PENDING=YELLOW(), RUNNABLE=BOLD() + YELLOW(),
                         STARTING=GREEN(), RUNNING=GREEN(),
                         SUCCEEDED=BOLD() + GREEN(), FAILED=BOLD() + RED())
job_states = job_status_colors.keys()
parser = register_listing_parser(ls, parent=batch_parser, help="List Batch jobs")
parser.add_argument("--queues", nargs="+").completer = complete_queue_name
parser.add_argument("--status", nargs="+", default=job_states, choices=job_states)

def get_job_desc(job_id):
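    # Job descriptions eventually expire from the Batch API; fall back to the copy archived in S3
    # (maintained by the Lambda helper; see ensure_lambda_helper and the warning in submit).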
    try:
        return clients.batch.describe_jobs(jobs=[job_id])["jobs"][0]
    except IndexError:
        bucket = resources.s3.Bucket("aegea-batch-jobs-{}".format(ARN.get_account_id()))
        return json.loads(bucket.Object("job_descriptions/{}".format(job_id)).get()["Body"].read())

def describe(args):
    return get_job_desc(args.job_id)

parser = register_parser(describe, parent=batch_parser, help="Describe a Batch job")
parser.add_argument("job_id")

def format_job_status(status):
    return job_status_colors[status] + status + ENDC()

def print_event(event):
    print(str(Timestamp(event["timestamp"])) + " " + event["message"])

def get_logs(args, print_event_fn=print_event):
    for event in CloudwatchLogReader(args.log_stream_name, head=args.head, tail=args.tail):
        print_event_fn(event)

def watch(args, print_event_fn=print_event):
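    # Poll the job description, log status transitions, stream the CloudWatch log stream once the
    # container reports one, and surface a nonzero container exit code as the return value.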
    job_desc = get_job_desc(args.job_id)
    args.job_name = job_desc["jobName"]
    logger.info("Watching job %s (%s)", args.job_id, args.job_name)
    last_status, log_reader = None, None
    while last_status not in {"SUCCEEDED", "FAILED"}:
        job_desc = get_job_desc(args.job_id)
        if job_desc["status"] != last_status:
            logger.info("Job %s %s", args.job_id, format_job_status(job_desc["status"]))
            last_status = job_desc["status"]
            if job_desc["status"] in {"RUNNING", "SUCCEEDED", "FAILED"}:
                logger.info("Job %s log stream: %s", args.job_id, job_desc.get("container", {}).get("logStreamName"))
        if job_desc["status"] in {"RUNNING", "SUCCEEDED", "FAILED"}:
            if "logStreamName" in job_desc.get("container", {}):
                args.log_stream_name = job_desc["container"]["logStreamName"]
                if log_reader is None:
                    log_reader = CloudwatchLogReader(args.log_stream_name, head=args.head, tail=args.tail)
                for event in log_reader:
                    print_event_fn(event)
        if "statusReason" in job_desc:
            logger.info("Job %s: %s", args.job_id, job_desc["statusReason"])
        if job_desc.get("container", {}).get("exitCode"):
            return SystemExit(job_desc["container"]["exitCode"])
        time.sleep(1)

get_logs_parser = register_parser(get_logs, parent=batch_parser, help="Retrieve logs for a Batch job")
get_logs_parser.add_argument("log_stream_name")
watch_parser = register_parser(watch, parent=batch_parser, help="Monitor a running Batch job and stream its logs")
watch_parser.add_argument("job_id")
for parser in get_logs_parser, watch_parser:
    lines_group = parser.add_mutually_exclusive_group()
    lines_group.add_argument("--head", type=int, nargs="?", const=10,
                             help="Retrieve this number of lines from the beginning of the log (default 10)")
    lines_group.add_argument("--tail", type=int, nargs="?", const=10,
                             help="Retrieve this number of lines from the end of the log (default 10)")

def ssh(args):
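    # Resolve the job's ECS container instance and task, then use aegea ssh to reach the EC2 host
    # and "docker exec" into the job's container.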
    if not args.ssh_args:
        args.ssh_args = ["/bin/bash", "-l"]
    job_desc = clients.batch.describe_jobs(jobs=[args.job_id])["jobs"][0]
    logger.info("Job %s %s", args.job_id, format_job_status(job_desc["status"]))
    job_queue_desc = clients.batch.describe_job_queues(jobQueues=[job_desc["jobQueue"]])["jobQueues"][0]
    ce = job_queue_desc["computeEnvironmentOrder"][0]["computeEnvironment"]
    ce_desc = clients.batch.describe_compute_environments(computeEnvironments=[ce])["computeEnvironments"][0]
    if "containerInstanceArn" not in job_desc["container"]:
        raise AegeaException("Job {} has not been dispatched to a container instance".format(args.job_id))
    ecs_ci_arn = job_desc["container"]["containerInstanceArn"]
    ecs_ci_desc = clients.ecs.describe_container_instances(cluster=ce_desc["ecsClusterArn"],
                                                           containerInstances=[ecs_ci_arn])["containerInstances"][0]
    ecs_ci_ec2_id = ecs_ci_desc["ec2InstanceId"]
    logger.info("Job {} is on EC2 instance {}".format(args.job_id, ecs_ci_ec2_id))
    ecs_task_arn = job_desc["container"]["taskArn"]
    res = clients.ecs.describe_tasks(cluster=ce_desc["ecsClusterArn"], tasks=[ecs_task_arn])
    if len(res["tasks"]) == 0:
        raise AegeaException("No ECS task found for job {}".format(args.job_id))
    container_id = res["tasks"][0]["containers"][0]["runtimeId"]
    logger.info("Job {} is in container {}".format(args.job_id, container_id))
    aegea_ssh(aegea_ssh_parser.parse_args(["-t", "-l", "ec2-user", ecs_ci_ec2_id,
                                           "docker", "exec", "--interactive", "--tty", container_id] + args.ssh_args))

ssh_parser = register_parser(ssh, parent=batch_parser, help="Log in to a running Batch job via SSH")
ssh_parser.add_argument("job_id")
ssh_parser.add_argument("ssh_args", nargs=argparse.REMAINDER)