import re
import os
import shutil
import textwrap
import argparse
import ujson as json

from pyspark import SparkContext, SparkConf
from sift.format import ModelFormat

import logging
# getLogger() with no name returns the root logger, so messages emitted through
# `log` go to whatever handlers the application has configured on the root.
log = logging.getLogger() 

class DatasetBuilder(object):
    """ Wrapper for modules which extract models of entities or text from a corpus of linked documents

    An instance is configured from keyword arguments (typically the parsed
    command line namespace) and, when called, runs the selected model on a
    new SparkContext and either saves or samples the formatted result.

    Subclasses are expected to supply a ``providers()`` classmethod returning
    the model classes to expose as sub-commands; it is not defined here.
    """
    def __init__(self, **kwargs):
        """Instantiate the formatter and the model from keyword arguments.

        Required keys: ``output_path``, ``sample``, ``fmtcls`` and
        ``modelcls``.  The formatter receives only those remaining keys that
        match parameters of ``fmtcls.__init__``; the model receives every
        remaining key (formatter options are not removed before the model is
        constructed).

        Raises:
            KeyError: if one of the required keys is missing.
        """
        self.output_path = kwargs.pop('output_path')
        self.sample = kwargs.pop('sample')

        fmtcls = kwargs.pop('fmtcls')
        fmt_args = {p: kwargs[p] for p in self._init_param_names(fmtcls) if p in kwargs}
        self.formatter = fmtcls(**fmt_args)

        modelcls = kwargs.pop('modelcls')
        # Turn a CamelCase class name into a spaced, human-readable label.
        self.model_name = re.sub('([A-Z])', r' \1', modelcls.__name__).strip()

        log.info("Building %s...", self.model_name)
        self.model = modelcls(**kwargs)

    @staticmethod
    def _init_param_names(klass):
        """Return the names of the declared parameters of ``klass.__init__``.

        Only real parameters are returned: ``co_varnames`` also lists local
        variables of the function, so it is sliced to the argument count.
        Classes without a Python-level ``__init__`` yield an empty tuple.
        """
        code = getattr(klass.__init__, '__code__', None)
        if code is None:
            return ()
        # co_kwonlyargcount only exists on Python 3 code objects.
        nparams = code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)
        return code.co_varnames[:nparams]

    @staticmethod
    def _doc_parts(klass):
        """Return ``(first_line, dedented_text)`` of a class docstring.

        A missing docstring is treated as empty so that undocumented classes
        can still be registered as sub-commands.
        """
        doc = klass.__doc__ or ''
        return doc.split('\n')[0], textwrap.dedent(doc.rstrip())

    def __call__(self):
        """Run the model on a new SparkContext and save or sample the output.

        The result of ``model.prepare(sc)`` is passed as keyword arguments to
        ``model.build``; the built dataset then goes through
        ``model.format_items`` and the formatter.  If an output path was
        given, the result is written there as gzip-compressed text, replacing
        any existing local directory at that path.  Otherwise, when
        ``sample`` is positive, that many items are printed to stdout.
        """
        c = SparkConf().setAppName('Build %s' % self.model_name)

        log.info('Using spark master: %s', c.get('spark.master'))
        sc = SparkContext(conf=c)

        # Always release the context, including when the build fails, so the
        # process does not keep cluster resources or block a later context.
        try:
            kwargs = self.model.prepare(sc)
            m = self.model.build(**kwargs)
            m = self.model.format_items(m)
            m = self.formatter(m)

            if self.output_path:
                log.info("Saving to: %s", self.output_path)
                if os.path.isdir(self.output_path):
                    log.warning('Writing over output path: %s', self.output_path)
                    shutil.rmtree(self.output_path)
                m.saveAsTextFile(self.output_path, 'org.apache.hadoop.io.compress.GzipCodec')
            elif self.sample > 0:
                # Parenthesised single argument: valid on Python 2 and 3.
                print('\n'.join(str(i) for i in m.take(self.sample)))
        finally:
            sc.stop()

        log.info('Done.')

    @classmethod
    def add_arguments(cls, p):
        """Register builder options and one sub-command per provider on ``p``.

        Adds ``--save`` and ``--sample``, stores this class as the ``cls``
        default, and creates a sub-parser for every class returned by
        ``cls.providers()``.  Each sub-parser is handed to the model's own
        ``add_arguments`` and then receives the formatter sub-commands.

        Returns:
            The parser ``p``.
        """
        p.add_argument('--save', dest='output_path', required=False, default=None, metavar='OUTPUT_PATH')
        p.add_argument('--sample', dest='sample', required=False, default=1, type=int, metavar='NUM_SAMPLES')
        p.set_defaults(cls=cls)

        sp = p.add_subparsers()
        for modelcls in cls.providers():
            name = modelcls.__name__
            help_str, desc = cls._doc_parts(modelcls)
            csp = sp.add_parser(name,
                                help=help_str,
                                description=desc,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
            modelcls.add_arguments(csp)
            cls.add_formatter_arguments(csp)

        return p

    @classmethod
    def add_formatter_arguments(cls, p):
        """Register one sub-command per output format on ``p``.

        The sub-command name is the lower-cased class name of each option
        from ``ModelFormat.iter_options()`` with a trailing ``format``
        removed.  Each sub-parser is handed to the format's ``add_arguments``.

        Returns:
            The parser ``p``.
        """
        sp = p.add_subparsers()
        for fmtcls in ModelFormat.iter_options():
            name = fmtcls.__name__.lower()
            if name.endswith('format'):
                name = name[:-len('format')]
            help_str, desc = cls._doc_parts(fmtcls)
            csp = sp.add_parser(name,
                                help=help_str,
                                description=desc,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
            fmtcls.add_arguments(csp)
        return p