"""Ordered dictionary implementation.

"""

import collections as co
from itertools import count
from operator import eq
import sys

from sortedcontainers import SortedDict
from sortedcontainers.sortedlist import recursive_repr

if sys.hexversion < 0x03000000:
    from itertools import imap # pylint: disable=wrong-import-order, ungrouped-imports
    map = imap # pylint: disable=redefined-builtin, invalid-name

try:
    from collections import abc as co_abc # pylint: disable=ungrouped-imports
except ImportError:  # Python 2 keeps the ABCs directly in ``collections``.
    co_abc = co # pylint: disable=invalid-name

# Private sentinel used as the default of ``OrderedDict.pop`` so that an
# explicit ``default=None`` can be told apart from "no default supplied".
NONE = object()


class KeysView(co_abc.KeysView):
    """Read-only view of mapping keys.

    Extends the ``KeysView`` ABC with support for ``reversed()``, delegating
    to the underlying mapping's own ``__reversed__``.

    """
    # pylint: disable=too-few-public-methods
    def __reversed__(self):
        "``reversed(keys_view)``"
        return reversed(self._mapping)


class ItemsView(co_abc.ItemsView):
    """Read-only view of mapping items.

    Extends the ``ItemsView`` ABC with support for ``reversed()``, yielding
    ``(key, value)`` pairs in the reverse of the mapping's iteration order.

    """
    # pylint: disable=too-few-public-methods
    def __reversed__(self):
        "``reversed(items_view)``"
        mapping = self._mapping
        for key in reversed(mapping):
            yield key, mapping[key]


class ValuesView(co_abc.ValuesView):
    """Read-only view of mapping values.

    Extends the ``ValuesView`` ABC with support for ``reversed()``, yielding
    values in the reverse of the mapping's key iteration order.

    """
    # pylint: disable=too-few-public-methods
    def __reversed__(self):
        "``reversed(values_view)``"
        mapping = self._mapping
        for key in reversed(mapping):
            yield mapping[key]


class SequenceView(object):
    """Read-only view of mapping keys as sequence.

    Wraps a sorted mapping of insertion number -> key (``OrderedDict`` passes
    its internal ``SortedDict``). Positional access goes through that
    mapping's ``iloc`` to find the insertion number at a position, then
    looks up the key stored under it.

    """
    # pylint: disable=too-few-public-methods
    def __init__(self, nums):
        self._nums = nums

    def __len__(self):
        "``len(sequence_view)``"
        return len(self._nums)

    def __getitem__(self, index):
        """``sequence_view[index]``

        An integer index returns a single key; a slice returns a list of keys.

        """
        nums = self._nums
        positions = nums.iloc[index]
        if isinstance(index, slice):
            # Slicing ``iloc`` yields a list of insertion numbers; map each
            # one back to its key rather than using the (unhashable) list
            # itself as a lookup key.
            return [nums[num] for num in positions]
        return nums[positions]


class OrderedDict(dict):
    """Dictionary that remembers insertion order and is numerically indexable.

    Keys are numerically indexable using the ``iloc`` attribute. For example::

        >>> ordered_dict = OrderedDict.fromkeys('abcde')
        >>> ordered_dict.iloc[0]
        'a'
        >>> ordered_dict.iloc[-2:]
        ['d', 'e']

    The ``iloc`` attribute behaves as a sequence-view for the mapping.

    Internally every key is tagged with an increasing insertion number drawn
    from ``itertools.count``:

    * ``_keys`` maps key -> insertion number.
    * ``_nums`` is a ``SortedDict`` mapping insertion number -> key, so
      walking it in sorted order yields keys in insertion order.

    Values themselves live in the underlying ``dict`` storage.

    """
    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        self._keys = {}
        self._nums = nums = SortedDict()
        self._count = count()
        self.iloc = SequenceView(nums)
        # Route initial items through ``__setitem__`` so each one is numbered.
        self.update(*args, **kwargs)

    def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
        """``ordered_dict[key] = value``

        A new key is assigned the next insertion number; overwriting an
        existing key keeps its original position.

        """
        if key not in self:
            num = next(self._count)
            self._keys[key] = num
            self._nums[num] = key
        dict_setitem(self, key, value)

    def __delitem__(self, key, dict_delitem=dict.__delitem__):
        """``del ordered_dict[key]``

        The ``dict`` deletion runs first, so a missing key raises ``KeyError``
        before any ordering state is modified.

        """
        dict_delitem(self, key)
        num = self._keys.pop(key)
        del self._nums[num]

    def __iter__(self):
        "``iter(ordered_dict)``"
        return self._nums.itervalues()

    def __reversed__(self):
        "``reversed(ordered_dict)``"
        nums = self._nums
        for key in reversed(nums):
            yield nums[key]

    def clear(self, dict_clear=dict.clear):
        """Remove all items from mapping.

        The insertion counter is not reset; numbering simply continues.

        """
        dict_clear(self)
        self._keys.clear()
        self._nums.clear()

    def popitem(self, last=True):
        """Remove and return (key, value) item pair.

        Pairs are returned in LIFO order if last is True or FIFO order if
        False.

        :raises KeyError: if the mapping is empty (matching ``dict.popitem``).

        """
        if not self:
            # Without this guard ``iloc`` would leak an ``IndexError``; the
            # mapping contract (and generic ``MutableMapping`` code) expects
            # ``KeyError`` for an empty container.
            raise KeyError('dictionary is empty')
        index = -1 if last else 0
        nums = self._nums
        num = nums.iloc[index]
        key = nums[num]
        value = self.pop(key)
        return key, value

    # Reuse the generic ABC ``update``: it assigns via ``self[key] = value``,
    # so every new key goes through ``__setitem__`` and gets numbered.
    update = __update = co_abc.MutableMapping.update

    def keys(self):
        "List of keys in mapping."
        return list(self.iterkeys())

    def items(self):
        "List of (key, value) item pairs in mapping."
        return list(self.iteritems())

    def values(self):
        "List of values in mapping."
        return list(self.itervalues())

    def iterkeys(self):
        "Return iterator over the keys in mapping."
        return self._nums.itervalues()

    def iteritems(self):
        "Return iterator over the (key, value) item pairs in mapping."
        for key in self._nums.itervalues():
            yield key, self[key]

    def itervalues(self):
        "Return iterator over the values in mapping."
        for key in self._nums.itervalues():
            yield self[key]

    def viewkeys(self):
        "Return set-like object with view of mapping keys."
        return KeysView(self)

    def viewitems(self):
        "Return set-like object with view of mapping items."
        return ItemsView(self)

    def viewvalues(self):
        "Return object with view of mapping values."
        return ValuesView(self)

    def pop(self, key, default=NONE):
        """Remove given key and return corresponding value.

        If key is not found, default is returned if given, otherwise raise
        KeyError.

        """
        if key in self:
            value = self[key]
            del self[key]
            return value
        elif default is NONE:
            raise KeyError(key)
        else:
            return default

    def setdefault(self, key, default=None):
        """Return ``mapping.get(key, default)``, also set ``mapping[key] = default`` if
        key not in mapping.

        """
        if key in self:
            return self[key]
        self[key] = default
        return default

    @recursive_repr
    def __repr__(self):
        "Text representation of mapping."
        return '%s(%r)' % (self.__class__.__name__, self.items())

    __str__ = __repr__

    def __reduce__(self):
        """Support for pickling serialization.

        Reconstructs the mapping from its ordered list of item pairs, which
        preserves insertion order.

        """
        return (self.__class__, (self.items(),))

    def copy(self):
        "Return shallow copy of mapping."
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Return new mapping with keys from iterable.

        If not specified, value defaults to None.

        """
        return cls((key, value) for key in iterable)

    def __eq__(self, other):
        """Test self and other mapping for equality.

        Comparison with another ``OrderedDict`` is order-sensitive; comparison
        with any other object falls back to plain ``dict`` equality.

        """
        if isinstance(other, OrderedDict):
            return dict.__eq__(self, other) and all(map(eq, self, other))
        else:
            return dict.__eq__(self, other)

    # Keep ``!=`` the negation of the order-sensitive ``__eq__`` above instead
    # of inheriting ``dict``'s order-insensitive inequality.
    __ne__ = co_abc.MutableMapping.__ne__

    def _check(self):
        "Check consistency of internal member variables."
        # pylint: disable=protected-access
        keys = self._keys
        nums = self._nums

        # Every stored key must be numbered exactly once in both directions.
        assert len(keys) == len(nums) == len(self)

        for key, value in keys.items():
            assert nums[value] == key

        nums._check()