"""
Command line entry point for apollo: build the argument parser and dispatch to the
subcommand handlers (warmup, resetdb, preprocess, bags, hash, query, cc, dumpcc,
cmd, dumpcmd, evalcc).
"""
import argparse
import json
import logging
import sys
from time import time

from igraph import Graph
from modelforge.logs import setup_logging
from sourced.ml import extractors
from sourced.ml.utils import add_engine_args, add_spark_args
from sourced.ml.cmd import ArgumentDefaultsHelpFormatterNoNone
from sourced.ml.cmd.args import add_bow_args, add_feature_args, add_repo2_args, \
    add_df_args, add_repartitioner_arg

from apollo.bags import preprocess, source2bags
from apollo.cassandra_utils import reset_db
from apollo.graph import find_connected_components, dumpcc, detect_communities, dumpcmd, \
    evaluate_communities
from apollo.hasher import hash_batches
from apollo.query import query
from apollo.warmup import warmup


# Maven coordinates of the Spark <-> Cassandra connector, supplied to Spark via --packages.
CASSANDRA_PACKAGE = "com.datastax.spark:spark-cassandra-connector_2.11:2.0.3"


def get_parser() -> argparse.ArgumentParser:
    """
    Create the cmdline argument parser.
    """
    parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatterNoNone)
    parser.add_argument("--log-level", default="INFO", choices=logging._nameToLevel,
                        help="Logging verbosity.")

    def add_feature_weight_arg(my_parser):
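        """Add one --<extractor>-weight option per registered feature extractor."""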
        help_desc = "%s's weight - all features from this extractor will be multiplied by this " \
                    "factor."
        for ex in extractors.__extractors__.values():
            my_parser.add_argument("--%s-weight" % ex.NAME, default=1, type=float,
                                   help=help_desc % ex.__name__)

    def add_cassandra_args(my_parser):
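        """Add the Cassandra connection options: host:port, keyspace and table name mapping."""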
        my_parser.add_argument(
            "--cassandra", default="0.0.0.0:9042", help="Cassandra's host:port.")
        my_parser.add_argument("--keyspace", default="apollo",
                               help="Cassandra's key space.")
        my_parser.add_argument(
            "--tables", help="Table name mapping (JSON): bags, hashes, hashtables, hashtables2.")

    def add_wmh_args(my_parser, params_help: str, add_hash_size: bool, required: bool):
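        """Add the Weighted MinHash options: the optional hash size, the parameters file
        and the Jaccard similarity threshold with the false positive/negative weights."""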
        if add_hash_size:
            my_parser.add_argument("--size", type=int, default=128, help="Hash size.")
        my_parser.add_argument("-p", "--params", required=required, help=params_help)
        my_parser.add_argument("-t", "--threshold", required=required, type=float,
                               help="Jaccard similarity threshold.")
        my_parser.add_argument("--false-positive-weight", type=float, default=0.5,
                               help="Used to adjust the relative importance of "
                                    "minimizing false positives count when optimizing "
                                    "for the Jaccard similarity threshold.")
        my_parser.add_argument("--false-negative-weight", type=float, default=0.5,
                               help="Used to adjust the relative importance of "
                                    "minimizing false negatives count when optimizing "
                                    "for the Jaccard similarity threshold.")

    def add_template_args(my_parser, default_template):
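        """Add the query batch size and the Jinja2 template options."""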
        my_parser.add_argument("--batch", type=int, default=100,
                               help="Number of hashes to query at a time.")
        my_parser.add_argument("--template", default=default_template,
                               help="Jinja2 template to render.")

    # Create the subparsers, one per command
    subparsers = parser.add_subparsers(help="Commands", dest="command")

    # ------------------------------------------------------------------------
    warmup_parser = subparsers.add_parser(
        "warmup", help="Initialize source{d} engine.")
    warmup_parser.set_defaults(handler=warmup)
    add_engine_args(warmup_parser, default_packages=[CASSANDRA_PACKAGE])

    # ------------------------------------------------------------------------
    db_parser = subparsers.add_parser("resetdb", help="Destructively initialize the database.")
    db_parser.set_defaults(handler=reset_db)
    add_cassandra_args(db_parser)
    db_parser.add_argument(
        "--hashes-only", action="store_true",
        help="Only clear the tables: hashes, hashtables, hashtables2. Do not touch the rest.")

    # ------------------------------------------------------------------------
    preprocess_parser = subparsers.add_parser(
        "preprocess",
        help="Create the index, quant and docfreq models for the bag-of-words pipeline.")
    preprocess_parser.set_defaults(handler=preprocess)
    add_df_args(preprocess_parser)
    add_repo2_args(preprocess_parser)
    add_feature_args(preprocess_parser)
    add_repartitioner_arg(preprocess_parser)
    preprocess_parser.add_argument(
        "--cached-index-path", default=None,
        help="[OUT] Path to the docfreq model holding the document's index.")

    # ------------------------------------------------------------------------
    source2bags_parser = subparsers.add_parser(
        "bags", help="Convert source code to weighted sets.")
    source2bags_parser.set_defaults(handler=source2bags)
    add_bow_args(source2bags_parser)
    add_repo2_args(source2bags_parser, default_packages=[CASSANDRA_PACKAGE])
    add_feature_args(source2bags_parser)
    add_cassandra_args(source2bags_parser)
    add_df_args(source2bags_parser)
    add_repartitioner_arg(source2bags_parser)
    source2bags_parser.add_argument(
        "--cached-index-path", default=None,
        help="[IN] Path to the docfreq model holding the document's index.")

    # ------------------------------------------------------------------------
    hash_parser = subparsers.add_parser(
        "hash", help="Run MinHashCUDA on the bag batches.")
    hash_parser.set_defaults(handler=hash_batches)
    hash_parser.add_argument("-i", "--input",
                             help="Path to the directory with Parquet files.")
    hash_parser.add_argument("--seed", type=int, default=int(time()),
                             help="Random generator's seed.")
    hash_parser.add_argument("--mhc-verbosity", type=int, default=1,
                             help="MinHashCUDA logs verbosity level.")
    hash_parser.add_argument("--devices", type=int, default=0,
                             help="Or-red indices of NVIDIA devices to use. 0 means all.")
    add_wmh_args(hash_parser, "Path to the output file with WMH parameters.", True, True)
    add_cassandra_args(hash_parser)
    add_spark_args(hash_parser, default_packages=[CASSANDRA_PACKAGE])
    add_feature_weight_arg(hash_parser)
    add_repartitioner_arg(hash_parser)

    # ------------------------------------------------------------------------
    query_parser = subparsers.add_parser("query", help="Query for similar files.")
    query_parser.set_defaults(handler=query)
    mode_group = query_parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("-i", "--id", help="Query for this id (id mode).")
    mode_group.add_argument("-c", "--file", help="Query for this file (file mode).")
    query_parser.add_argument("--docfreq", help="Path to OrderedDocumentFrequencies (file mode).")
    query_parser.add_argument("--min-docfreq", default=1, type=int,
                              help="The minimum document frequency of each feature.")
    query_parser.add_argument(
        "--bblfsh", default="localhost:9432", help="Babelfish server's address.")
    query_parser.add_argument("--precise", action="store_true",
                              help="Calculate the precise set.")
    add_wmh_args(query_parser, "Path to the Weighted MinHash parameters.", False, False)
    add_feature_args(query_parser, required=False)
    add_template_args(query_parser, "query.md.jinja2")
    add_cassandra_args(query_parser)

    # ------------------------------------------------------------------------
    cc_parser = subparsers.add_parser(
        "cc", help="Load the similar pairs of files and run connected components analysis.")
    cc_parser.set_defaults(handler=find_connected_components)
    add_cassandra_args(cc_parser)
    cc_parser.add_argument("-o", "--output", required=True,
                           help="[OUT] Path to connected components ASDF model.")

    # ------------------------------------------------------------------------
    dumpcc_parser = subparsers.add_parser(
        "dumpcc", help="Output the connected components to stdout.")
    dumpcc_parser.set_defaults(handler=dumpcc)
    dumpcc_parser.add_argument("-i", "--input", required=True,
                               help="Path to connected components ASDF model.")

    # ------------------------------------------------------------------------
    community_parser = subparsers.add_parser(
        "cmd", help="Run Community Detection analysis on the connected components from \"cc\".")
    community_parser.set_defaults(handler=detect_communities)
    community_parser.add_argument("-i", "--input", required=True,
                                  help="Path to connected components ASDF model.")
    community_parser.add_argument("-o", "--output", required=True,
                                  help="[OUT] Path to the communities ASDF model.")
    community_parser.add_argument("--edges", choices=("linear", "quadratic", "1", "2"),
                                  default="linear",
                                  help="The method to generate the graph's edges: bipartite - "
                                       "linear and fast, but may not fit some the CD algorithms, "
                                       "or all to all within a bucket - quadratic and slow, but "
                                       "surely fits all the algorithms.")
    cmd_choices = [k[10:] for k in dir(Graph) if k.startswith("community_")]
    community_parser.add_argument("-a", "--algorithm", choices=cmd_choices,
                                  default="walktrap",
                                  help="The community detection algorithm to apply.")
    community_parser.add_argument("-p", "--params", type=json.loads, default={},
                                  help="Parameters for the algorithm (**kwargs, JSON format).")
    community_parser.add_argument("--no-spark", action="store_true", help="Do not use Spark.")
    add_spark_args(community_parser)

    # ------------------------------------------------------------------------
    dumpcmd_parser = subparsers.add_parser(
        "dumpcmd", help="Output the detected communities to stdout.")
    dumpcmd_parser.set_defaults(handler=dumpcmd)
    dumpcmd_parser.add_argument("input", help="Path to the communities ASDF model.")
    add_template_args(dumpcmd_parser, "report.md.jinja2")
    add_cassandra_args(dumpcmd_parser)

    # ------------------------------------------------------------------------
    evalcc_parser = subparsers.add_parser(
        "evalcc", help="Evaluate the communities: calculate the precise similarity and the "
                       "fitness metric.")
    evalcc_parser.set_defaults(handler=evaluate_communities)
    evalcc_parser.add_argument("-t", "--threshold", required=True, type=float,
                               help="Jaccard similarity threshold.")
    evalcc_parser.add_argument("-i", "--input", required=True,
                               help="Path to the communities model.")

    add_spark_args(evalcc_parser, default_packages=[CASSANDRA_PACKAGE])
    add_cassandra_args(evalcc_parser)

    # TODO: retable [.....] -> [.] [.] [.] [.] [.]
    return parser


def main():
    """
    Parse the command line arguments and invoke the handler registered via set_defaults().

    :return: The result of the invoked handler.
    """
    parser = get_parser()
    args = parser.parse_args()
    args.log_level = logging._nameToLevel[args.log_level]
    setup_logging(args.log_level)
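    # argparse does not set "handler" when no subcommand is given, so fall back
    # to printing the usage message.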
    try:
        handler = args.handler
    except AttributeError:
        def print_usage(_):
            parser.print_usage()

        handler = print_usage
    return handler(args)


if __name__ == "__main__":
    sys.exit(main())