# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import copy
from itertools import groupby
from operator import itemgetter
import sys

from pyspark.rdd import RDD
from pyspark_cassandra.conf import ReadConf, WriteConf
from pyspark_cassandra.format import ColumnSelector, RowFormat
from pyspark_cassandra.types import Row
from pyspark_cassandra.util import as_java_array, as_java_object, helper


if sys.version_info > (3,):
    long = int  # @ReservedAssignment


try:
    import pandas as pd  # @UnusedImport, used in SpanningRDD.asDataFrames
except ImportError:
    pd = None


def saveToCassandra(rdd, keyspace=None, table=None, columns=None, row_format=None, keyed=None,
                    write_conf=None, **write_conf_kwargs):
    '''
        Saves an RDD to Cassandra. The RDD is expected to contain dicts with keys mapping to CQL
        columns.

        Arguments:
        @param rdd(RDD):
            The RDD to save. Equal to self when invoking saveToCassandra on a monkey-patched RDD.
        @param keyspace(string):
            The keyspace to save the RDD in. If not given and the rdd is a CassandraRDD the same
            keyspace is used.
        @param table(string):
            The CQL table to save the RDD in. If not given and the rdd is a CassandraRDD the same
            table is used.

        Keyword arguments:
        @param columns(iterable):
            The columns to save, i.e. which keys to take from the dicts in the RDD.
            If None is given, all columns are stored.

        @param row_format(RowFormat):
            Make explicit how to map the RDD elements into Cassandra rows.
            If None is given, the mapping is auto-detected as far as possible.
        @param keyed(bool):
            Make explicit that the RDD consists of key, value tuples (and not arrays of length
            two).

        @param write_conf(WriteConf):
            A WriteConf object to use when saving to Cassandra
        @param **write_conf_kwargs:
            WriteConf parameters to use when saving to Cassandra
    '''

    keyspace = keyspace or getattr(rdd, 'keyspace', None)
    if not keyspace:
        raise ValueError("keyspace not set")

    table = table or getattr(rdd, 'table', None)
    if not table:
        raise ValueError("table not set")

    # create write config as map
    write_conf = WriteConf.build(write_conf, **write_conf_kwargs)
    write_conf = as_java_object(rdd.ctx._gateway, write_conf.settings())
    # convert the columns to a string array
    columns = as_java_array(rdd.ctx._gateway, "String", columns) if columns else None

    helper(rdd.ctx) \
        .saveToCassandra(
            rdd._jrdd,
            keyspace,
            table,
            columns,
            row_format,
            keyed,
            write_conf,
        )
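
# A minimal usage sketch (the SparkContext 'sc' and the keyspace/table/column names
# below are assumptions, not part of this module):
#
#   rdd = sc.parallelize([{"key": 1, "value": "a"}, {"key": 2, "value": "b"}])
#   saveToCassandra(rdd, keyspace="test", table="kv", columns=["key", "value"])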


class _CassandraRDD(RDD):
    '''
        A Resilient Distributed Dataset of Cassandra CQL rows. As with any RDD, objects of this
        class are immutable; i.e. operations on this RDD generate a new RDD.
    '''

    def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, **read_conf_kwargs):
        if not keyspace:
            raise ValueError("keyspace not set")

        if not table:
            raise ValueError("table not set")

        if row_format is None:
            row_format = RowFormat.ROW
        elif row_format < 0 or row_format >= len(RowFormat.values):
            raise ValueError("invalid row_format %s" % row_format)

        self.keyspace = keyspace
        self.table = table
        self.row_format = row_format
        self.read_conf = ReadConf.build(read_conf, **read_conf_kwargs)
        self._limit = None

        # this jrdd is for compatibility with pyspark.rdd.RDD
        # while allowing this constructor to be used for type checking etc.
        # and setting _jrdd //after// invoking this constructor
        class DummyJRDD(object):
            def id(self):
                return -1
        jrdd = DummyJRDD()

        super(_CassandraRDD, self).__init__(jrdd, ctx)


    @property
    def _helper(self):
        return helper(self.ctx)


    def _pickle_jrdd(self):
        jrdd = self._helper.pickleRows(self._crdd, self.row_format)
        return self._helper.javaRDD(jrdd)


    def get_crdd(self):
        return self._crdd

    def set_crdd(self, crdd):
        self._crdd = crdd
        self._jrdd = self._pickle_jrdd()
        self._id = self._jrdd.id

    crdd = property(get_crdd, set_crdd)


    saveToCassandra = saveToCassandra


    def select(self, *columns):
        """Creates a CassandraRDD with the select clause applied."""
        columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in columns))
        return self._specialize('select', columns)
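
    # For example (a sketch; the keyspace, table and column names are assumptions):
    #   CassandraTableScanRDD(sc, "test", "kv").select("key", "value")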


    def where(self, clause, *args):
        """Creates a CassandraRDD with a CQL where clause applied.
        @param clause: The where clause, either complete or with ? markers
        @param *args: The parameters for the ? markers in the where clause.
        """
        args = as_java_array(self.ctx._gateway, "Object", args)
        return self._specialize('where', *[clause, args])
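
    # For example (a sketch; the keyspace, table, predicate and value are assumptions):
    #   CassandraTableScanRDD(sc, "test", "kv").where("key = ?", 1)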


    def limit(self, limit):
        """Creates a CassandraRDD with the limit clause applied."""
        self._limit = limit
        return self._specialize('limit', long(limit))


    def take(self, num):
        """Takes at most 'num' records from the Cassandra table.

        Note that if limit() was invoked before take(), a normal pyspark take()
        is performed. Otherwise, the limit is set first and _then_ take() is
        performed.
        """
        if self._limit:
            return super(_CassandraRDD, self).take(num)
        else:
            return self.limit(num).take(num)
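
    # For example (a sketch; names are assumptions): both lines read at most 10 rows,
    # but the first pushes the limit of 10 down to Cassandra, while the second reads
    # at most 5 rows because the limit was applied before take():
    #   CassandraTableScanRDD(sc, "test", "kv").take(10)
    #   CassandraTableScanRDD(sc, "test", "kv").limit(5).take(10)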


    def cassandraCount(self):
        """Lets Cassandra perform a count, instead of loading data to Spark"""
        return self._crdd.cassandraCount()


    def _specialize(self, func_name, *args, **kwargs):
        func = getattr(self._helper, func_name)

        new = copy(self)
        new.crdd = func(new._crdd, *args, **kwargs)

        return new


    def spanBy(self, *columns):
        """"Groups rows by the given columns without shuffling.

        @param *columns: an iterable of columns by which to group.

        Note that:
        - The rows are grouped by comparing the given columns in order and
          starting a new group whenever the value of the given columns changes.
          This works well when using the partition keys and one or more of the
          clustering keys. Use rdd.groupBy(...) for any other grouping.
        - The grouping is applied at the partition level, i.e. any group
          will be a subset of its containing partition.
        """

        return SpanningRDD(self.ctx, self._crdd, self._jrdd, self._helper, columns)
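
    # For example (a sketch; 'user' is assumed to be the partition key of the table):
    #   spanned = CassandraTableScanRDD(sc, "test", "events").spanBy("user")
    #   # each element is a (key Row, list of Rows) pair, grouped within a partition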


    def __copy__(self):
        c = self.__class__.__new__(self.__class__)
        c.__dict__.update(self.__dict__)
        return c



class CassandraTableScanRDD(_CassandraRDD):
    def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, **read_conf_kwargs):
        super(CassandraTableScanRDD, self).__init__(ctx, keyspace, table, row_format, read_conf,
                                                    **read_conf_kwargs)

        self._key_by = ColumnSelector.none()

        read_conf = as_java_object(ctx._gateway, self.read_conf.settings())

        self.crdd = self._helper \
            .cassandraTable(
                ctx._jsc,
                keyspace,
                table,
                read_conf,
            )


    def by_primary_key(self):
        return self.key_by(primary_key=True)

    def key_by(self, primary_key=True, partition_key=False, *columns):
        # TODO implement keying by arbitrary columns
        if columns:
            raise NotImplementedError('keying by arbitrary columns is not (yet) supported')
        if partition_key:
            raise NotImplementedError('keying by partition key is not (yet) supported')

        new = copy(self)
        new._key_by = ColumnSelector(partition_key, primary_key, *columns)
        new.crdd = self.crdd

        return new
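
    # For example (a sketch; the keyspace and table names are assumptions): key each
    # row by its primary key columns:
    #   CassandraTableScanRDD(sc, "test", "kv").by_primary_key()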


    def _pickle_jrdd(self):
        # TODO implement keying by arbitrary columns
        jrdd = self._helper.pickleRows(self.crdd, self.row_format, self._key_by.primary_key)
        return self._helper.javaRDD(jrdd)



class SpanningRDD(RDD):
    '''
        An RDD which groups rows with the same key (as defined through named
        columns) within each partition.
    '''
    def __init__(self, ctx, crdd, jrdd, helper, columns):
        self._crdd = crdd
        self.columns = columns
        self._helper = helper

        rdd = RDD(jrdd, ctx).mapPartitions(self._spanning_iterator())
        super(SpanningRDD, self).__init__(rdd._jrdd, ctx)


    def _spanning_iterator(self):
        ''' implements basic spanning on the python side operating on Rows '''
        # TODO implement in Java and support not only Rows

        columns = set(str(c) for c in self.columns)

        def spanning_iterator(partition):
            def key_by(columns):
                for row in partition:
                    k = Row(**{c: row.__getattr__(c) for c in columns})
                    for c in columns:
                        del row[c]

                    yield (k, row)

            for g, l in groupby(key_by(columns), itemgetter(0)):
                yield g, list(_[1] for _ in l)

        return spanning_iterator


    def asDataFrames(self, *index_by):
        '''
            Reads the spanned rows as DataFrames if pandas is available, as a
            dict of numpy arrays if only numpy is available, or as a dict of
            primitives and objects otherwise.

            @param index_by If pandas is available, the dataframes will be
            indexed by the given columns.
        '''
        for c in index_by:
            if c in self.columns:
                raise ValueError('column %s cannot be used as index in the '
                                 'dataframes as it is a column by which the rows are spanned.' % c)

        columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in self.columns))
        jrdd = self._helper.spanBy(self._crdd, columns)
        rdd = RDD(jrdd, self.ctx)

        global pd
        if index_by and pd:
            return rdd.mapValues(lambda _: _.set_index(*[str(c) for c in index_by]))
        else:
            return rdd
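
    # For example (a sketch; table and column names are assumptions):
    #   spanned = CassandraTableScanRDD(sc, "test", "events").spanBy("user")
    #   frames = spanned.asDataFrames("timestamp")
    #   # yields (key, DataFrame) pairs indexed by 'timestamp', if pandas is available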


def joinWithCassandraTable(left_rdd, keyspace, table):
    '''
        Join an RDD with a Cassandra table on the partition key. Use .on(...)
        to specify other columns to join on. .select(...), .where(...) and
        .limit(...) can be used as well.

        Arguments:
        @param left_rdd(RDD):
            The RDD to join. Equal to self when invoking joinWithCassandraTable on a monkey-patched
            RDD.
        @param keyspace(string):
            The keyspace to join on.
        @param table(string):
            The CQL table to join on.
    '''

    return CassandraJoinRDD(left_rdd, keyspace, table)
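
# A minimal usage sketch (names are assumptions): join an RDD of partition keys with a
# Cassandra table and select a subset of columns:
#
#   joined = joinWithCassandraTable(left_rdd, "test", "kv").select("value")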


class CassandraJoinRDD(_CassandraRDD):
    '''
        An RDD resulting from joining another RDD with a Cassandra table on the partition key,
        or on the columns specified with .on(...).
    '''

    def __init__(self, left_rdd, keyspace, table):
        super(CassandraJoinRDD, self).__init__(left_rdd.ctx, keyspace, table)
        self.crdd = self._helper \
            .joinWithCassandraTable(
                left_rdd._jrdd,
                keyspace,
                table
            )


    def on(self, *columns):
        columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in columns))
        return self._specialize('on', columns)