#!/usr/bin/env python3

import os
import shlex
import shutil
import subprocess
from datetime import datetime
from urllib import parse

import boto3
import click


# Paths selected by the 'preprocess' group.  All paths in this module are
# relative; the save/restore commands hand them to tar as written.
PRE_PROCESS_TARGETS = {
    'data/external/wikidata_instance-of.pickle',
    'data/external/wikipedia/all_wiki_redirects.csv',
    'data/external/wikipedia/cache_SUCCESS',
    'data/external/wikipedia/dump_redirects.pickle',
    'data/external/wikipedia/pages',
    'data/external/wikipedia/parsed-wiki',
    'data/external/wikipedia/parsed-wiki_SUCCESS',
    'data/external/wikipedia/wikipedia-titles.pickle',
}

# Every remaining group consists of a single path under output/.
GUESS_TARGETS = {'output/guesser'}
VW_INPUT = {'output/vw_input'}
VW_MODELS = {'output/models'}
PREDICTIONS = {'output/predictions'}
SUMMARIES = {'output/summary'}
REPORTING = {'output/reporting'}
EXPO = {'output/expo'}

# Group name -> set of paths that the name expands to.
TARGET_GROUPS = {
    'preprocess': PRE_PROCESS_TARGETS,
    'guesser': GUESS_TARGETS,
    'vw_input': VW_INPUT,
    'vw_models': VW_MODELS,
    'predictions': PREDICTIONS,
    'summaries': SUMMARIES,
    'reporting': REPORTING,
    'expo': EXPO,
}

# Union of every individual path; it is also reachable as the 'all' group.
CHECKPOINT_TARGETS = set().union(*TARGET_GROUPS.values())
TARGET_GROUPS['all'] = CHECKPOINT_TARGETS

# Everything accepted as a target on the command line: group names and paths.
CHECKPOINT_CHOICES = CHECKPOINT_TARGETS.union(TARGET_GROUPS)


class S3:
    """Helper around a boto3 S3 resource for listing and creating runs.

    A run is identified by a ``YYYY-MM-DD`` string and corresponds to the key
    prefix ``<namespace>/<run id>/`` inside ``bucket``.
    """

    def __init__(self, bucket, namespace):
        """Remember ``bucket``/``namespace`` and open a boto3 S3 resource."""
        self.s3 = boto3.resource('s3')
        self.bucket = bucket
        self.namespace = namespace

    def list_runs(self):
        """Yield the run identifier of every run stored under the namespace.

        Yields nothing when the namespace holds no runs.
        """
        base = self.namespace + '/'
        paginator = self.s3.meta.client.get_paginator('list_objects_v2')
        # Paginate so that listings larger than a single response are not
        # silently truncated.
        pages = paginator.paginate(
            Bucket=self.bucket,
            Prefix=base,
            Delimiter='/'
        )
        for page in pages:
            # 'CommonPrefixes' is absent from the response when nothing
            # matches, so an empty namespace must not raise KeyError.
            for entry in page.get('CommonPrefixes', []):
                # Strip the namespace prefix rather than splitting on '/', so
                # a namespace that itself contains '/' still works.
                yield entry['Prefix'][len(base):].rstrip('/')

    def create_run(self, date):
        """Create the run ``date`` by uploading a ``run_id`` marker object.

        The marker is written to ``/tmp/qb/run_id`` first and then uploaded to
        ``<namespace>/<date>/run_id``.
        """
        os.makedirs('/tmp/qb', exist_ok=True)

        with open('/tmp/qb/run_id', 'w') as f:
            f.write(date)

        self.s3.meta.client.upload_file(
            '/tmp/qb/run_id',
            self.bucket,
            '{}/{}/run_id'.format(self.namespace, date)
        )

    def latest_run(self):
        """Return the identifier of the chronologically latest run.

        The identifier is returned exactly as stored, so it always names an
        existing prefix even if it was not written zero-padded.

        Raises:
            ValueError: if there are no runs, or if a run identifier is not a
                ``YYYY-MM-DD`` date.
        """
        run_ids = list(self.list_runs())
        if not run_ids:
            raise ValueError('There are no runs so therefore there is no latest run')
        return max(run_ids, key=lambda run_id: datetime.strptime(run_id, '%Y-%m-%d'))


def fetch(variable, environment_variable):
    """Return ``variable``, falling back to an environment variable.

    Args:
        variable: Explicitly supplied value; returned unchanged unless it is
            ``None``.
        environment_variable: Name of the environment variable consulted when
            ``variable`` is ``None``.

    Raises:
        ValueError: if ``variable`` is ``None`` and the environment variable is
            unset or empty.
    """
    if variable is not None:
        return variable

    from_environment = os.environ.get(environment_variable)
    if not from_environment:
        # Covers both "unset" (None) and "set to the empty string".
        raise ValueError('You must set {} or pass the variable as an option'.format(
            environment_variable))
    return from_environment


def shell(command):
    """Run ``command`` through a shell and fail if it does not succeed.

    When bash is available the command runs with ``pipefail`` enabled, so a
    failure in any stage of a pipeline (not only the last one) is reported.
    Without bash the command runs through the default shell, where only the
    exit status of the last pipeline stage is checked.

    Args:
        command: Shell command line as a single string.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    bash = shutil.which('bash')
    if bash is None:
        return subprocess.run(command, check=True, shell=True)
    return subprocess.run([bash, '-o', 'pipefail', '-c', command], check=True)


def compile_targets(targets):
    """Expand group names in ``targets`` into the paths they stand for.

    Each entry that is a key of ``TARGET_GROUPS`` contributes every member of
    that group; any other entry contributes itself.

    Returns:
        A new set containing the expanded targets.
    """
    expanded = set()
    for target in targets:
        # A non-group entry is treated as a one-element group of itself.
        expanded.update(TARGET_GROUPS.get(target, {target}))
    return expanded


@click.group()
@click.option('--bucket', help='AWS S3 bucket to checkpoint and restore from')
@click.option('--namespace', help='Namespace within bucket to checkpoint and restore from')
@click.pass_context
def cli(ctx, bucket, namespace):
    """Checkpoint files to AWS S3 and restore them.

    The bucket and namespace are taken from the options, or from the
    QB_AWS_S3_BUCKET and QB_AWS_S3_NAMESPACE environment variables when the
    options are not given.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('/tmp/qb', exist_ok=True)

    # Make sure ctx.obj is a dict even when the group is invoked without
    # obj={} (for example through an installed entry point).
    ctx.ensure_object(dict)
    ctx.obj['s3'] = S3(
        fetch(bucket, 'QB_AWS_S3_BUCKET'),
        fetch(namespace, 'QB_AWS_S3_NAMESPACE')
    )


@cli.command(name='list')
@click.pass_context
def list_runs(ctx):
    """Print every run identifier in sorted order, one per line."""
    run_ids = sorted(ctx.obj['s3'].list_runs())
    for run_id in run_ids:
        print(run_id)


@cli.command()
@click.pass_context
def latest(ctx):
    """Print the identifier of the most recent run."""
    storage = ctx.obj['s3']
    print(storage.latest_run())


@cli.command()
@click.pass_context
def keys(ctx):
    """Print the target group names in sorted order, one per line."""
    group_names = sorted(TARGET_GROUPS.keys())
    for group_name in group_names:
        print(group_name)


@cli.command()
@click.option('--date', help='Date to use for run identifier in YYYY-MM-DD format')
@click.pass_context
def create(ctx, date):
    """Create a new run, identified by --date or by today's date.

    An explicit --date must parse as YYYY-MM-DD and is stored zero-padded, so
    the run can later be found and ordered by its identifier.
    """
    if date is None:
        date = datetime.now().strftime('%Y-%m-%d')
    else:
        try:
            # Round-trip through datetime: rejects malformed input and
            # normalises e.g. 2017-1-5 to 2017-01-05.
            date = datetime.strptime(date, '%Y-%m-%d').strftime('%Y-%m-%d')
        except ValueError:
            raise click.BadParameter(
                '{!r} is not a date in YYYY-MM-DD format'.format(date),
                param_hint='--date'
            )
    ctx.obj['s3'].create_run(date)


@cli.command()
@click.option('--date', help="Which date to save the qanta run to, by default the most recent")
@click.argument('targets', nargs=-1, type=click.Choice(sorted(CHECKPOINT_CHOICES)), required=True)
@click.pass_context
def save(ctx, date, targets):
    """Archive TARGETS and upload them to a run in S3.

    Each target is packed with tar, compressed with lz4 and copied to
    s3://<bucket>/<namespace>/<date>/<url-quoted target>.  Group names are
    expanded to the paths they contain.
    """
    s3 = ctx.obj['s3']
    if date is None:
        date = s3.latest_run()

    selected = sorted(compile_targets(targets))

    # Refuse to start if something is missing: archiving a non-existent path
    # would otherwise upload an empty archive in its place.
    missing = [t for t in selected if not os.path.exists(t)]
    if missing:
        raise click.ClickException(
            'Cannot save, these targets do not exist: {}'.format(', '.join(missing)))

    for t in selected:
        name = parse.quote_plus(t)
        archive = '/tmp/qb/{name}.tar.lz4'.format(name=name)
        remote = 's3://{bucket}/{namespace}/{date}/{name}'.format(
            name=name,
            bucket=s3.bucket,
            namespace=s3.namespace,
            date=date
        )
        try:
            # Quote every interpolated value so unusual characters in the
            # bucket, namespace or date cannot alter the command line.
            shell('tar cvf - {target} | lz4 > {archive}'.format(
                target=shlex.quote(t), archive=shlex.quote(archive)))
            shell('aws s3 cp {archive} {remote}'.format(
                archive=shlex.quote(archive), remote=shlex.quote(remote)))
        finally:
            # Remove the temporary archive even when a step above failed.
            if os.path.exists(archive):
                os.remove(archive)


@cli.command()
@click.option('--date', help="Which date to restore the qanta run from")
@click.argument('targets', nargs=-1, type=click.Choice(sorted(CHECKPOINT_CHOICES)), required=True)
@click.pass_context
def restore(ctx, date, targets):
    """Download TARGETS from a run in S3 and unpack them here.

    Each archive is fetched from
    s3://<bucket>/<namespace>/<date>/<url-quoted target>, decompressed with
    lz4 and extracted with tar into the current directory.  When --date is
    not given the most recent run is used.  Group names are expanded to the
    paths they contain.
    """
    s3 = ctx.obj['s3']
    if date is None:
        date = s3.latest_run()

    for t in sorted(compile_targets(targets)):
        name = parse.quote_plus(t)
        archive = '/tmp/qb/{name}.tar.lz4'.format(name=name)
        remote = 's3://{bucket}/{namespace}/{date}/{name}'.format(
            name=name,
            bucket=s3.bucket,
            namespace=s3.namespace,
            date=date
        )
        try:
            # Quote every interpolated value so unusual characters in the
            # bucket, namespace or date cannot alter the command line.
            shell('aws s3 cp {remote} {archive}'.format(
                remote=shlex.quote(remote), archive=shlex.quote(archive)))
            # -c selects stdout explicitly instead of relying on lz4 choosing
            # it implicitly when its output is a pipe.
            shell('lz4 -d -c {archive} | tar -x -C .'.format(
                archive=shlex.quote(archive)))
        finally:
            # Remove the temporary archive even when a step above failed.
            if os.path.exists(archive):
                os.remove(archive)


if __name__ == '__main__':
    # Start with an empty dict as the click context object; the `cli` group
    # stores the S3 helper in it under the 's3' key for the sub-commands.
    cli(obj={})