import copy
import numpy as np
import random

from schema_games.utils import \
    compute_edge_nzis, compute_shape_from_nzis, \
    get_distinct_colors, shape_to_nzis, offset_nzis_from_position
from schema_games.breakout.constants import \
    _MAX_SPEED, _BALL_SHAPE, CLASSIC_BALL_COLOR, DEFAULT_PADDLE_COLOR, \
    CLASSIC_WALL_COLOR, MAX_NZIS_PER_ENTITY, CLASSIC_BRICK_COLORS


###############################################################################
# API
###############################################################################

class BreakoutObject(object):
    """
    Base class for objects in BreakoutEngine.

    Parameters
    ----------
    position : [int, int] or (int, int) or numpy.ndarray
        Initial position coordinates of the object.
    nzis : [(int, int)] or None
        Nonzero indices for the parts/pixels of the object. (0, 0) corresponds
        to object position and is located at the top left corner of the object.
    shape : (int, int) or None
        Convenience constructor for rectangular objects. Generates the correct
        nzis for a rectangle that shape. When not None, nzis has to be None.
    hitpoints : int
        Number of hitpoints that object has. Note that in order for special
        behavior to happen when that number reaches a special value (e.g.,
        destroy object when hitpoints = 0), that behavior has to be encoded in
        a special _destruction_effect(self, env) method where env is an
        instance of BreakoutEngine.
    is_entity : bool
        If True, the engine reports the object position as an entity.
    color : (int, int, int)
        Color of the object, to be used when rendering the environment and to
        be returned as an observable visual attribute if applicable. Each value
        is an int between 0 and 255 (included).
    visible : bool
        If False, the object is not rendered in the visual output, and not
        considered for physical interaction (e.g., a ball will not bounce off
        of it). This is useful to temporarily remove an object that might
        reappear later.
    entity_id : None or int
        Unique identifier of this entity used when parsing the game image into
        primitive entities (based on shape and color) in the proxy vision
        system. This is useful if we don't want to bother trying to track each
        of the entities by solving a binding problem between frames.
    indirect_collision_effects : bool
        If True, then indirect collisions trigger effects for this object; i.e.
        the special methods presented below may be called during a indirect
        collision. May be set to False in subclasses that yield rewards in
        order to avoid reward entanglement during indirect collisions. Note
        that even if this option is set to False, indirect collisions are still
        taken into account to determine the ball's bounce trajectory.

    Special methods
    ---------------
    '_collision_effect' will be called when a ball collides with this object,
    whether it destroys it or not.

    '_destruction_effect' will be called when a ball collides with this object
    and destroys it.
    """
    unique_entity_id = 0
    unique_object_id = 0
    unique_color_id = 0

    # Map from RGB color to a unique id for that color.
    color_map = {}

    def __init__(self,
                 position,
                 nzis=None,
                 shape=None,
                 hitpoints=np.PINF,
                 is_entity=True,
                 color=(128, 128, 128),
                 visible=True,
                 indirect_collision_effects=True):

        self._position = np.array(position)
        self.hitpoints = hitpoints
        self.is_entity = is_entity
        self.color = color
        self.visible = visible
        self.is_rectangular = True
        self.indirect_collision_effects = indirect_collision_effects

        assert self.hitpoints > 0
        assert ((nzis is None and shape is not None) or
                (nzis is not None and shape is None))

        # Set non-zero indices of the object mask's
        if nzis is None:
            self._nzis = shape_to_nzis(shape)
        else:
            self._nzis = np.array(nzis)

        if is_entity:
            self.entity_id = BreakoutObject.unique_entity_id
            BreakoutObject.unique_entity_id += MAX_NZIS_PER_ENTITY
            assert len(self._nzis) <= MAX_NZIS_PER_ENTITY
        else:
            self.entity_id = None

        self.object_id = BreakoutObject.unique_object_id
        BreakoutObject.unique_object_id += 1
        BreakoutObject.register_color(self.color)

        # Sets up slots for memoization
        self.reset_cache()

    def reset_cache(self):
        """
        If the shape changes for any reason, we need to reset cached values.
        Caching these values is useful to reduce overhead as the game engine
        looks up these properties frequently.
        """
        self._cached_shape = None
        self._cached_offset_nzis = None
        self._cached_offset_edge_nzis = None
        self._cached_nzis_min = None
        self._cached_nzis_max = None

    ###########################################################################
    # Protected attributes: setting any of them triggers a cache refresh
    ###########################################################################

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, pos):
        if all(np.array(pos) == self._position):
            return
        try:
            self._position = np.array(pos)
        except:
            raise
        finally:
            self.reset_cache()

    @property
    def nzis(self):
        return self._nzis

    @nzis.setter
    def nzis(self, new_nzis):
        try:
            self._nzis = new_nzis
        except:
            raise
        finally:
            self.reset_cache()

    ###########################################################################
    # Read-only, cached attributes that derived from `nzis`
    ###########################################################################

    @property
    def offset_edge_nzis(self):
        """
        Return only the non-zero indices on the edge / boundary of the object,
        and offset them by the position of the object within the game frame.
        """
        if self._cached_offset_edge_nzis is None:
            self._cached_offset_edge_nzis = offset_nzis_from_position(
                compute_edge_nzis(self._nzis), self.position)
        return self._cached_offset_edge_nzis

    @offset_edge_nzis.setter
    def offset_edge_nzis(self, other):
        raise RuntimeError("Setting `offset_edge_nzis` is not supported.")

    @property
    def offset_nzis(self):
        """
        Return the non-zero indices of the object, offset by the position of
        the object within the game frame.
        """
        if self._cached_offset_nzis is None:
            self._cached_offset_nzis = \
                offset_nzis_from_position(self._nzis, self._position)
        return self._cached_offset_nzis

    @offset_nzis.setter
    def offset_nzis(self, other):
        raise RuntimeError("Setting `offset_nzis` is not supported.")

    @property
    def shape(self):
        """
        Shape of the object's bounding box.
        """
        if self._cached_shape is None:
            self._cached_shape = compute_shape_from_nzis(self._nzis)
        return self._cached_shape

    @shape.setter
    def shape(self, other):
        raise RuntimeError("Setting `shape` is not supported.")

    @property
    def nzis_min(self):
        """
        Easy optimization if the object is known to be rectangular.
        """
        if self._cached_nzis_min is None:
            self._cached_nzis_min = self._nzis.min(axis=0)
        return self._cached_nzis_min

    @nzis_min.setter
    def nzis_min(self, other):
        raise RuntimeError("Setting `nzis_min` is not supported.")

    @property
    def nzis_max(self):
        """
        Easy optimization if the object is known to be rectangular.
        """
        if self._cached_nzis_max is None:
            self._cached_nzis_max = self._nzis.max(axis=0)
        return self._cached_nzis_max

    @nzis_max.setter
    def nzis_max(self, other):
        raise RuntimeError("Setting `nzis_max` is not supported.")

    ###########################################################################
    # Other properties
    ###########################################################################

    def contains_position(self, pos):
        """
        Does this object contain that position?

        Parameters
        ----------
        pos : (int, int)
            Some position in the game.

        Returns
        -------
        bool
        """
        if self.is_rectangular:
            return self.contains_position_within_bounding_box(pos)
        else:
            return pos in self.offset_nzis

    def contains_position_within_bounding_box(self, pos):
        """
        Check whether some position lies within the object's bounding box.
        We are using this for collision detection, so this is assuming the
        object is square.
        """
        if pos[0] < self.position[0] + self.nzis_min[0]:
            return False
        if pos[0] > self.position[0] + self.nzis_max[0]:
            return False
        if pos[1] < self.position[1] + self.nzis_min[1]:
            return False
        if pos[1] > self.position[1] + self.nzis_max[1]:
            return False
        return True

    def _collision_effect(self, environment):
        """
        Called by the environment when the ball collides with this object
        without destroying it.

        Parameters
        ----------
        environment : BreakoutEngine
            Instance of the game that called that method in the first place. It
            is passed as an argument to allow it to be changed when a collision
            with self occurs.
        """
        pass

    def _destruction_effect(self, environment):
        """
        Called by the environment when the ball collides with this object and
        destroys it.

        Parameters
        ----------
        environment : BreakoutEngine
            Instance of the game that called that method in the first place. It
            is passed as an argument to allow it to be changed when a collision
            with self occurs.
        """
        pass

    @classmethod
    def register_color(cls, color):
        """
        Registers unique colors within the class attribute `BreakoutObject.
        color_map`. This is a legacy attribute that we do not use anywhere, but
        we keep it as it might be useful for debugging.
        """
        if color not in BreakoutObject.color_map:
            BreakoutObject.color_map[color] = BreakoutObject.unique_color_id
            BreakoutObject.unique_color_id += 1


class MoveableObject(BreakoutObject):
    """
    Base class for all objects whose position may change.
    """
    pass


class MomentumObject(MoveableObject):
    """
    Base class for all objects with action-independent momentum.
    """
    pass


###############################################################################
# Different kinds of bricks
###############################################################################

class Brick(BreakoutObject):
    """
    Base class for Breakout bricks. It has a custom attribute 'reward' that is
    only accessed within this class.
    """
    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        reward : int
            Reward upon collision.
        """
        kwargs.setdefault('color', random.choice(get_distinct_colors(6)))
        kwargs.setdefault('hitpoints', 1)
        kwargs.setdefault('indirect_collision_effects', False)
        self.reward = kwargs.pop('reward', 1)

        super(Brick, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        self.hitpoints -= 1

    def _destruction_effect(self, environment):
        environment.reward += self.reward
        environment.brick_hit_counter += 1
        environment.bricks.remove(self)

    @staticmethod
    def brick_colors_classic(num_bricks):
        """
        Helper function.
        """
        if num_bricks < len(CLASSIC_BRICK_COLORS):
            return (CLASSIC_BRICK_COLORS[:-2] +
                    [CLASSIC_BRICK_COLORS[-1]])
        else:
            return CLASSIC_BRICK_COLORS


class StrongBrick(Brick):
    """
    Bricks that take multiple hits to be destroyed.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('hitpoints', 3)
        super(StrongBrick, self).__init__(*args, **kwargs)

        self.init_hitpoints = copy.copy(self.hitpoints)
        self.reward = copy.copy(self.hitpoints)
        self.init_color = copy.copy(self.color)

    def _collision_effect(self, environment):
        """
        Dim brick color
        """
        self.hitpoints -= 1

        # Change color
        coef = float(self.hitpoints) / self.init_hitpoints
        self.color = (
                int(self.init_color[0] * coef),
                int(self.init_color[1] * coef),
                int(self.init_color[2] * coef),
            )

    def _destruction_effect(self, environment):
        environment.reward += self.reward
        environment.brick_hit_counter += 1
        environment.bricks.remove(self)


class PaddleShrinkingBrick(Brick):
    """
    Bricks that shrink the paddle when hit.
    """
    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        shrinkage : int
            Amount by which to shrink the paddle.
        """
        kwargs.setdefault('color', (242, 79, 34))
        self.shrinkage = kwargs.pop('shrinkage', 4)
        super(PaddleShrinkingBrick, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        """
        Shrink paddle
        """
        self.hitpoints -= 1
        environment.paddle.shrink(self.shrinkage, min_length=1)

    def _destruction_effect(self, environment):
        environment.brick_hit_counter += 1
        environment.bricks.remove(self)


class PaddleGrowingBrick(Brick):
    """
    Bricks that grow the paddle when hit.
    """
    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        growth : int
            Amount by which to grow the paddle.
        """
        kwargs.setdefault('color', (34, 197, 242))
        self.growth = kwargs.pop('growth', 4)
        super(PaddleGrowingBrick, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        """
        Grow paddle
        """
        self.hitpoints -= 1
        max_length = environment.width - 2 * environment.wall_thickness
        environment.paddle.grow(self.growth, max_length=max_length)

    def _destruction_effect(self, environment):
        environment.brick_hit_counter += 1
        environment.bricks.remove(self)


class AcceleratorBrick(Brick):
    """
    Bricks that permanently accelerate the ball when hit.
    """
    trigger_counter = _MAX_SPEED - 1

    def __init__(self, *args, **kwargs):
        super(AcceleratorBrick, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        """
        Accelerate ball. To be triggered only a limited number of times.
        """
        self.hitpoints -= 1

        if AcceleratorBrick.trigger_counter > 0:
            new_bmr = environment.ball_movement_radius + 1
            environment.ball_movement_radius = min(new_bmr, _MAX_SPEED)
            AcceleratorBrick.trigger_counter -= 1

    def _destruction_effect(self, environment):
        environment.reward += self.reward
        environment.brick_hit_counter += 1
        environment.bricks.remove(self)


class ResetterBrick(Brick):
    """
    Upon collision, this brick yields some reward and then calls the
    environment's layout function again to reset the game.
    """
    def __init__(self, *args, **kwargs):
        super(ResetterBrick, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        environment.reward += self.reward
        environment.randomize_bricks_positions()

    def _destruction_effect(self, environment):
        pass


###############################################################################
# Other entities: paddle, ball, wall, etc.
###############################################################################

class Paddle(BreakoutObject):
    """
    Paddle. Note that ball-paddle collisions will *not* trigger a call of
    Paddle._collision_effect!
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('color', DEFAULT_PADDLE_COLOR)
        super(Paddle, self).__init__(*args, **kwargs)

    def shrink(self, amount, min_length):
        dx, dy = self.shape
        new_shape = (max(self.shape[0] - amount, min_length), dy)
        self.nzis = shape_to_nzis(new_shape)

    def grow(self, amount, max_length):
        dx, dy = self.shape
        new_shape = (min(self.shape[0] + amount, max_length), dy)
        self.nzis = shape_to_nzis(new_shape)

    def _collision_effect(self, environment):
        raise RuntimeError("Paddle-bound collisions not to be handled here.")

    def _destruction_effect(self, environment):
        raise RuntimeError("Paddle-bound collisions not to be handled here.")


class Ball(MomentumObject):
    """
    Ball. Unlike MomentumObject, it has a special attribute, velocity_index,
    that determines its velocity.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('color', CLASSIC_BALL_COLOR)

        # Force ball shape since it's an immutable parameter
        assert not {'shape', 'nzis'}.intersection(kwargs.keys())
        kwargs['shape'] = _BALL_SHAPE
        kwargs['nzis'] = None
        self.velocity_index = kwargs.pop('velocity_index', None)

        super(Ball, self).__init__(*args, **kwargs)


class Wall(BreakoutObject):
    """
    Wall. It has shape (1, 1) by default.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('color', CLASSIC_WALL_COLOR)

        if 'nzis' not in kwargs and 'shape' not in kwargs:
            kwargs['nzis'] = [(0, 0)]

        super(Wall, self).__init__(*args, **kwargs)

        # Things get complicated if walls can be invisible. Protect this.
        assert self.visible


class PaddleShrinkingWall(Wall):
    """
    Wall that reduces the size of the paddle when hit once. All instances share
    a class attribute 'trigger_count' to make sure that the effect be triggered
    only once.
    """
    trigger_count = False

    def __init__(self, *args, **kwargs):
        super(PaddleShrinkingWall, self).__init__(*args, **kwargs)

    def _collision_effect(self, environment):
        """
        Shrink paddle.
        """
        if PaddleShrinkingWall.trigger_count > 0:
            environment.paddle.shrink(environment.paddle.shape[0] // 2)
            PaddleShrinkingWall.trigger_count -= 1


class WallOfPunishment(Wall):
    """
    Virtual wall that is equivalent to losing a ball. Useful when training
    agents. Note that no collision effects are triggered and that rewards are
    handled by the engine normally.
    """
    def __init__(self, *args, **kwargs):
        """
        No need to do anything here, the reward is handled below.
        """
        super(WallOfPunishment, self).__init__(*args, **kwargs)

        # The whole point of this is to have an entity that makes negative
        # rewards given ball loss learnable by an entity-based agent ...
        self.is_entity = True
        self.visible = False


class HorizontallyMovingObstacle(MomentumObject):
    """
    Wall that bounces back and forth. Note that velocity is encoded differently
    than the ball, which is a special case.
    """
    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        velocity : (int, int)
            Initial velocity of the object.
        """
        self.velocity = kwargs.pop('velocity')
        assert self.velocity[1] == 0
        super(HorizontallyMovingObstacle, self).__init__(*args, **kwargs)