"Classes for arrays where one or more axes can be indexed by Part objects"

from __future__ import division

import numpy as np
from openmodes.parts import Part
import numbers
import collections
import six
import warnings

def part_ranges_lowest(parent_part, basis_container):
    """Construct the slice objects for the parent part, iterating down only
    to the specified lowest level

    Parameters
    ----------
    parent_part : Part
        The top of the part hierarchy
    basis_container : object
        Must have a ``lowest_parts`` attribute (the parts at which to stop
        descending), and be indexable by each of those parts, giving an
        object whose ``len`` is the size of that part's block

    Returns
    -------
    ranges : dict
        Maps each part visited by ``iter_lowest(..., parent_order='after')``
        to the slice it occupies along the axis
    """
    # NOTE(review): if lowest_parts is missing this is None and the
    # membership test below fails; part_ranges only calls us when it exists
    lowest = getattr(basis_container, 'lowest_parts', None)

    # Each lowest-level part occupies a contiguous block, in iteration order
    sizes = [len(basis_container[part])
             for part in parent_part.iter_lowest(lowest)]
    offsets = np.cumsum([0]+sizes)

    ranges = {}

    # index of lowest-level parts, used to get the offsets
    lowest_part_num = 0

    # iterate with parents last, so they can get data from their child objects
    for part in parent_part.iter_lowest(lowest, parent_order='after'):
        if part in lowest:
            ranges[part] = slice(offsets[lowest_part_num],
                                 offsets[lowest_part_num+1])
            lowest_part_num += 1
        else:
            # A parent spans from the start of its first child to the end of
            # its last child, both of which have already been visited
            start = ranges[part.children[0]].start
            stop = ranges[part.children[-1]].stop
            ranges[part] = slice(start, stop)

    return ranges

def part_ranges(parent_part, basis_container):
    """Construct the slice objects for the parent part and all of its children

    If `basis_container` has a ``lowest_parts`` attribute, the work is
    delegated to `part_ranges_lowest`.

    Parameters
    ----------
    parent_part : Part
        The top of the part hierarchy
    basis_container : object
        Indexable by each single part, giving an object whose ``len`` is the
        size of that part's block

    Returns
    -------
    ranges : dict
        Maps every part yielded by ``parent_part.iter_all`` to its slice
    """
    if hasattr(basis_container, 'lowest_parts'):
        return part_ranges_lowest(parent_part, basis_container)

    # Each single part occupies a contiguous block, in iteration order
    sizes = [len(basis_container[part]) for part in parent_part.iter_single()]
    offsets = np.cumsum([0]+sizes)

    ranges = {}

    # index of single parts to get the offsets
    single_part_num = 0

    # iterate with parents last, so they can get data from their child objects
    for part in parent_part.iter_all(parent_first=False):
        if hasattr(part, 'children'):
            # A parent spans from the start of its first child to the end of
            # its last child, both of which have already been visited
            start = ranges[part.children[0]].start
            stop = ranges[part.children[-1]].stop
            ranges[part] = slice(start, stop)
        else:
            ranges[part] = slice(offsets[single_part_num],
                                 offsets[single_part_num+1])
            single_part_num += 1

    return ranges

def build_lookup(index_data):
    """Create the lookup table for a LookupArray

    Parameters
    ----------
    index_data : iterable
        One element per axis, each of which is one of:
            - a tuple (Part, BasisContainer), for an axis indexed by Parts
            - an integer, for an axis of that fixed length with no metadata
            - a sequence of strings, naming the quantities along the axis

    Returns
    -------
    lookup : list
        Per axis: ``(ranges, basis_container, part)`` for a Part axis,
        ``None`` for an integer axis, or a dict mapping each quantity name
        to its position
    shape : list of int
        The length of each axis

    Raises
    ------
    ValueError
        If a quantity axis contains anything other than strings
    """
    # collections.Iterable was removed in Python 3.10
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2
        from collections import Iterable

    lookup = []
    shape = []
    for index in index_data:
        if (isinstance(index, Iterable) and len(index) > 0
                and isinstance(index[0], Part)):
            # a hierarchy of parts
            part = index[0]
            basis_container = index[1]
            ranges = part_ranges(part, basis_container)
            lookup.append((ranges, basis_container, part))
            # the top-level part's slice covers the whole axis
            shape.append(int(ranges[part].stop - ranges[part].start))
        elif isinstance(index, numbers.Integral):
            # an integer for a specific length, with no metadata
            lookup.append(None)
            shape.append(index)
        else:
            # A tuple of strings representing quantities
            this_lookup = {}
            for (x, y) in enumerate(index):
                if not isinstance(y, six.string_types):
                    raise ValueError("Unknown index type %s" % str(y))
                this_lookup[y] = x
            lookup.append(this_lookup)
            shape.append(len(this_lookup))

    return lookup, shape

def compatible_quantity_part(first, second):
    """Determine if two parts of lookup lists are compatible quantities and parts

    Parameters
    ----------
    first, second : sequence
        Two consecutive lookup entries, ``[quantity_dict, part_lookup]``,
        where ``part_lookup`` is ``(ranges, basis_container, ...)``

    Returns
    -------
    bool
        True if both quantity entries are dicts with the same set of
        positions, and both part entries share the same ranges and container
    """
    quant_first, quant_second = first[0], second[0]
    if not isinstance(quant_first, dict) or not isinstance(quant_second, dict):
        return False

    # NOTE(review): only the positions (values) are compared, not the
    # quantity names, so this effectively checks the number of quantities
    if set(quant_first.values()) != set(quant_second.values()):
        return False

    parts_first, parts_second = first[1], second[1]
    if not isinstance(parts_first[0], dict) or not isinstance(parts_second[0], dict):
        return False

    # compare the (ranges, basis_container) pairs
    return parts_first[:2] == parts_second[:2]

class LookupArray(np.ndarray):
    """A subclass of a numpy array, where for certain dimensions, Part objects
    or strings can be used to index array elements.

    Each axis has an entry in ``self.lookup``:
        - a tuple ``(ranges, basis_container, parent_part)`` for an axis
          indexed by Parts, where ``ranges`` maps each Part to a slice
        - a dict mapping quantity names (strings) to integer positions
        - ``None`` for an axis with no extra metadata

    For explanation of subclassing numpy arrays, see
    https://numpy.org/doc/stable/user/basics.subclassing.html

    The following functionality of numpy arrays may cause problems, so they
    should be used with "extreme caution":
        - transpose
        - adding new axes by indexing with np.newaxis/None
        - flattening
        - anything other than C ordering
        - Functions which reduce dimensions
        - Indexing with ...
    """

    def __new__(subtype, index_data=None, lookup=None, shape=None, dtype=float):
        """Construct an empty vector which can be indexed by parts

        Parameters
        ----------
        index_data : tuple, optional
            Tuple elements can be integer, for a fixed length,
            a tuple (Part, BasisContainer), for hierarchical indexing by Parts,
            or a tuple of strings, for quantities
        lookup, shape : tuple, optional
            Instead of providing index_data, these elements can be provided
            directly if the lookup table is known in advance
        dtype : dtype, optional
            The numpy data type of the vector

        Returns
        -------
        LookupArray
            A new array whose contents are not initialised
        """
        if lookup is None or shape is None:
            lookup, shape = build_lookup(index_data)
        obj = np.ndarray.__new__(subtype, shape, dtype)
        obj.lookup = lookup

        # Finally, we must return the newly created object:
        return obj

    def __array_finalize__(self, obj):
        "Function is called when creating array from view as well"
        if obj is None:
            # explicit construction: __new__ sets the lookup itself
            return

        # inherit the lookup table from the source array, if it has one
        self.lookup = getattr(obj, 'lookup', None)

    def __setstate__(self, state):
        """Allow additional attributes of this array type to be unpickled

        Note that some metadata may be lost when unpickling."""
        base_state, extended_state = state
        super(LookupArray, self).__setstate__(base_state)
        self.lookup, = extended_state

    def __reduce__(self):
        """Allow additional attributes of this array type to be pickled

        Note that some metadata may be lost when unpickling."""
        # ndarray.__reduce__ takes no arguments beyond self
        base_reduce = list(super(LookupArray, self).__reduce__())
        full_state = (base_reduce[2], (self.lookup,))
        base_reduce[2] = full_state
        return tuple(base_reduce)

    # Under python 3.x, these members will not be called. However, they should
    # not cause any trouble.
    def __getslice__(self, start, stop):
        "Needed due to CPython bug"
        return self.__getitem__(slice(start, stop))

    def __setslice__(self, start, stop, val):
        "Needed due to CPython bug"
        self.__setitem__(slice(start, stop), val)

    @staticmethod
    def _is_quantity_lookup(lookup_entry):
        "Whether an axis' lookup entry is a dict keyed by quantity names"
        return (isinstance(lookup_entry, dict) and len(lookup_entry) > 0 and
                isinstance(next(iter(lookup_entry)), six.string_types))

    def _full_lookup(self):
        "The lookup table, with None standing in for missing metadata"
        if self.lookup is None:
            return [None]*self.ndim
        return list(self.lookup)

    def __getitem__(self, idx):
        """Gets an item or items from the array. Any of the indices may be the
        name of a range, in addition to all the usual fancy indexing options

        Part indices select that part's slice and keep the axis; string
        indices select a single quantity and drop the axis.

        Raises
        ------
        IndexError
            If numpy rejects the translated index
        """
        if not isinstance(idx, tuple):
            # force a single index to be a tuple
            idx = idx,

        lookup = self._full_lookup()
        new_idx = []
        sub_lookup = []
        entry_num = 0

        # try to lookup every part of the index to convert to a range
        for entry in idx:
            if isinstance(entry, Part):
                # Need to pass this metadata to the sub-array for its
                # lookup table
                ranges, container, _ = lookup[entry_num]
                new_idx.append(ranges[entry])
                sub_lookup.append((part_ranges(entry, container), container, entry))
            elif isinstance(entry, six.string_types):
                # If a string has been passed, then this dimension will have
                # been flattened out, so no metadata is needed
                new_idx.append(lookup[entry_num][entry])
            else:
                new_idx.append(entry)

                if entry is None:
                    # np.newaxis adds a dimension without consuming one of
                    # ours, so keep the place in the original lookup
                    sub_lookup.append(None)
                    entry_num -= 1
                elif isinstance(entry, numbers.Integral):
                    # Integers mean a dimension is dropped
                    pass
                elif isinstance(entry, slice) and entry == slice(None):
                    # If slicing the whole dimension, metadata can be kept
                    sub_lookup.append(lookup[entry_num])
                else:
                    # Partial slices, arrays and other iterables: the
                    # dimension is kept but its metadata is lost
                    sub_lookup.append(None)

            entry_num += 1

        # now add lookup data for all the non-indexed dimensions
        sub_lookup = sub_lookup+lookup[entry_num:]

        try:
            result = super(LookupArray, self).__getitem__(tuple(new_idx))
        except IndexError as exc:
            message = "Invalid index %s" % str(idx)
            exc.args = (message,)+tuple(str(n) for n in exc.args[1:])
            raise

        # May get a LookupArray or an array scalar back
        if isinstance(result, LookupArray):
            result.lookup = sub_lookup

        return result

    def __setitem__(self, idx, value):
        """Sets an item or items in the array. Any of the indices may be the
        name of a range, in addition to all the usual fancy indexing options

        Raises
        ------
        IndexError
            If numpy rejects the translated index
        """
        if not isinstance(idx, tuple):
            # force a single index to be a tuple
            idx = idx,

        lookup = self._full_lookup()
        new_idx = []
        entry_num = 0

        # try to lookup every part of the index to convert to a range
        for entry in idx:
            if isinstance(entry, Part):
                ranges, _, _ = lookup[entry_num]
                new_idx.append(ranges[entry])
            elif isinstance(entry, six.string_types):
                new_idx.append(lookup[entry_num][entry])
            else:
                new_idx.append(entry)
                if entry is None:
                    # np.newaxis does not consume one of our dimensions
                    entry_num -= 1
            entry_num += 1

        try:
            super(LookupArray, self).__setitem__(tuple(new_idx), value)
        except IndexError as exc:
            # str() needed: formatting a tuple directly with % would fail
            message = "Invalid index %s" % str(idx)
            exc.args = (message,)+tuple(str(n) for n in exc.args[1:])
            raise

    def transpose(self, *args, **kwargs):
        "Not supported, since the lookup metadata cannot be permuted"
        raise NotImplementedError

    @property
    def T(self):
        "Transposed view, with the lookup metadata reversed to match"
        result = super(LookupArray, self).T
        if isinstance(result, LookupArray) and self.lookup is not None:
            result.lookup = list(reversed(self.lookup))
        return result

    def simple_view(self):
        """Return a view where quantity dimensions (with string keys) are
        collapsed into the subsequent dimension. View is of type ndarray.

        Consecutive quantity dimensions are all merged into the next
        non-quantity dimension; a trailing quantity dimension, which has no
        subsequent dimension, is kept on its own.
        """
        lookup = self._full_lookup()
        # axes beyond the end of the lookup table are plain axes
        lookup += [None]*(self.ndim-len(lookup))

        new_shape = []
        carry = None
        for dim_n, lu_n in zip(self.shape, lookup):
            if self._is_quantity_lookup(lu_n):
                carry = dim_n if carry is None else carry*dim_n
            else:
                new_shape.append(dim_n if carry is None else carry*dim_n)
                carry = None
        if carry is not None:
            new_shape.append(carry)
        return self.reshape(new_shape).view(np.ndarray)

    def dot(self, other):
        """Matrix/vector multiplication with another array

        Parameters
        ----------
        other : array_like or LookupArray
            If not a LookupArray, its first axis is contracted with the last
            axis of this array. If a LookupArray, the last two axes of this
            array (quantity, parts) are contracted with the first two axes
            of `other`, which must have compatible lookup data.

        Returns
        -------
        LookupArray, or a scalar if all axes are contracted

        Raises
        ------
        ValueError
            If the contracted axes of a plain array have different lengths
        NotImplementedError
            If the lookup data of two LookupArrays are not compatible
        """
        if not isinstance(other, LookupArray):
            other = np.asarray(other)
            if self.shape[-1] != other.shape[0]:
                raise ValueError("Cannot multiply: last axis has length %d but "
                                 "first axis of other has length %d"
                                 % (self.shape[-1], other.shape[0]))
            new_lookup = list(self.lookup[:-1])+[None]*(other.ndim-1)
            new_shape = self.shape[:-1]+other.shape[1:]
        elif compatible_quantity_part(self.lookup[-2:], other.lookup[:2]):
            new_lookup = list(self.lookup[:-2])+list(other.lookup[2:])
            new_shape = self.shape[:-2]+other.shape[2:]
            other = other.simple_view()
        else:
            raise NotImplementedError("Lookup data are not compatible for dot")

        product = np.dot(self.simple_view(), other)

        if len(new_shape) == 0:
            # handle the case of a scalar result
            return product

        new_array = LookupArray(lookup=new_lookup, shape=new_shape,
                                dtype=np.promote_types(self.dtype, other.dtype))
        # the freshly created array is contiguous, so its simple view
        # writes through to it
        new_array.simple_view()[:] = product
        return new_array

    def vdot(self, other):
        "Conjugated dot product"
        return self.conj().dot(other)

def view_lookuparray(original, index_data=None, lookup=None, shape=None):
    """Convert an array to a LookupArray, where possible avoiding copying

    Parameters
    ----------
    original : ndarray
        The array whose data is viewed
    index_data : tuple, optional
        Description of each axis, as accepted by `build_lookup`. Only used
        when `lookup` is not given.
    lookup : list, optional
        A pre-built lookup table; if given, `index_data` is ignored
    shape : tuple, optional
        Shape of the result. Ignored when `lookup` is not given (the shape
        then comes from `index_data`); when `lookup` is given without
        `shape`, the shape of `original` is kept.

    Returns
    -------
    LookupArray
        A view of `original`, unless `reshape` has to copy
    """
    if lookup is None:
        lookup, shape = build_lookup(index_data)
    elif shape is None:
        # avoid reshape(None); keep the array's existing shape
        shape = original.shape
    result = original.reshape(shape).view(LookupArray)
    result.lookup = lookup
    return result

def loop_star_indices(x):
    """Return the indices into the array corresponding to the loop and star
    parts. The array must have been constructed using loop/star basis
    functions

    Parameters
    ----------
    x: LookupArray
        Typically has 2 dimensions, both of which are indexable by Parts.
        Axes which are not indexed by Parts are taken whole.

    Returns
    -------
    indices_loop, indices_star: list
        For each dimension n, indices_loop[n] indexes the loop part, and
        indices_star[n] indexes the star part. For an axis indexed by Parts
        this is an integer ndarray shaped to broadcast against the arrays
        of later axes; for any other axis it is ``slice(None)``.
    """
    indices_loop = []
    indices_star = []

    for lookup_num, lookup in enumerate(x.lookup):
        # A Part axis has lookup (ranges, basis_container, ...), where
        # ranges is a dict keyed by Part
        is_part_axis = (isinstance(lookup, (tuple, list)) and len(lookup) >= 2
                        and isinstance(lookup[0], dict) and len(lookup[0]) > 0
                        and isinstance(next(iter(lookup[0])), Part))

        if not is_part_axis:
            # This index is not for parts, so just take the whole axis
            indices_loop.append(slice(None))
            indices_star.append(slice(None))
            continue

        ranges, container = lookup[0], lookup[1]

        # This index is a lookup for Parts, so find all the SingleParts
        # along this index and add the relevant ranges to the indexing
        # array. The parent part is the one covering the largest range.
        parent_part = max(ranges, key=lambda p: ranges[p].stop-ranges[p].start)

        loop_list = []
        star_list = []
        for part in parent_part.iter_single():
            part_start = ranges[part].start

            # assumes loop_range/star_range are slices relative to the start
            # of this part's block — TODO confirm against the basis classes
            bf = container[part]
            loop_range = bf.loop_range
            star_range = bf.star_range
            loop_list.append(np.arange(part_start+loop_range.start,
                                       part_start+loop_range.stop))
            star_list.append(np.arange(part_start+star_range.start,
                                       part_start+star_range.stop))

        # If this is not the last axis, then add the necessary number of
        # extra dimensions to each array so that they will be broadcast
        # correctly when the caller goes to use them
        new_shape = (-1,)+(1,)*(x.ndim-lookup_num-1)
        loop_array = (np.concatenate(loop_list) if loop_list
                      else np.zeros(0, dtype=int)).reshape(new_shape)
        star_array = (np.concatenate(star_list) if star_list
                      else np.zeros(0, dtype=int)).reshape(new_shape)
        indices_loop.append(loop_array)
        indices_star.append(star_array)

    return indices_loop, indices_star