# Copyright 2018/2019 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import, division, print_function

import re

import numpy as np
from six.moves import xrange as range_
from rlgraph import get_backend
from rlgraph.spaces.space import Space
from rlgraph.utils.initializer import Initializer
from rlgraph.utils.util import convert_dtype

if get_backend() == "tf":
    import tensorflow as tf
if get_backend() == "pytorch":
    import torch


class BoxSpace(Space):
    """
    A box in R^n with a shape tuple of len n. Each dimension may be bounded.
    """
    def __init__(self, low, high, shape=None, add_batch_rank=False, add_time_rank=False, time_major=False,
                 dtype=np.float32):
        """
        Args:
            low (any): The lower bound (see Valid Inputs for more information).
            high (any): The upper bound (see Valid Inputs for more information).
            shape (tuple): The shape of this Space.
            add_batch_rank (bool): Whether to add a batch rank to this Space (see parent class `Space`).
            add_time_rank (bool): Whether to add a time rank to this Space (see parent class `Space`).
            time_major (bool): Whether a time rank (if any) should come before the batch rank.
            dtype (np.dtype): The data type (as numpy type) for this Space.
                Allowed are np.int8/16/32/64, np.float16/32/64 and np.bool_.

        Valid inputs:
            BoxSpace(0.0, 1.0) # low and high are given as scalars and shape is assumed to be ()
                -> single scalar between low and high.
            BoxSpace(-1.0, 1.0, (3,4)) # low and high are scalars, and shape is provided -> nD array
                where all(!) elements are between low and high.
            BoxSpace(np.array([-1.0,-2.0]), np.array([2.0,4.0])) # low and high are arrays of the same shape
                (no shape given!) -> nD array where each dimension has different bounds.
        """
        super(BoxSpace, self).__init__(add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
                                       time_major=time_major)

        self.dtype = dtype

        # Determine the shape.
        if shape is None:
            if isinstance(low, (int, float, bool)):
                self._shape = ()
            else:
                self._shape = np.shape(low)
        else:
            assert isinstance(shape, (tuple, list)), "ERROR: `shape` must be None or a tuple/list."
            self._shape = tuple(shape)

        # Determine the bounds and set `self.global_bounds`:
        # - False if the bounds are individualized (each dimension has its own lower and upper bound,
        #   retrievable via self.low and self.high).
        # - A tuple of the global (low, high) values otherwise (applying to all values in all dimensions).
        # 0D Space.
        if self._shape == ():
            if isinstance(low, np.ndarray):
                assert low.shape == (), "ERROR: If shape == (), `low` must be scalar!"
                low = np.asscalar(low)
            if isinstance(high, np.ndarray):
                assert high.shape == (), "ERROR: If shape == (), `high` must be scalar!"
                high = np.asscalar(high)
            self.global_bounds = (low, high)
        # nD Space (n > 0). Bounds can be single number or individual bounds.
        else:
            # Low/high values are given individually per item.
            if isinstance(low, (list, tuple, np.ndarray)):
                self.global_bounds = False
            # Only one low/high value. Use these as generic bounds for all values.
            else:
                assert np.isscalar(low) and np.isscalar(high)
                self.global_bounds = (low, high)

        self.low = np.array(low)
        self.high = np.array(high)
        assert self.low.shape == self.high.shape

    def force_batch(self, samples, horizontal=None):
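        """
        Makes sure that `samples` comes with a batch rank, adding one (of size 1) if necessary.

        Args:
            samples (any): The samples to batch.
            horizontal (Optional[bool]): Not used by BoxSpace (kept for interface compatibility with
                other Space classes).

        Returns:
            Tuple[np.ndarray,bool]: The (possibly) batched samples and whether a batch rank had to be added.
        """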
        assert self.has_time_rank is False, "ERROR: Cannot force a batch rank if Space `has_time_rank` is True!"
        # Samples are 0D (i.e. certainly not batched) or have no extra rank compared to this Space
        # -> add a batch rank.
        if np.asarray(samples).ndim == 0 or \
                np.asarray(samples).ndim == len(self.get_shape(with_batch_rank=False, with_time_rank=False)):
            return np.array([samples]), True  # batch size=1
        # Samples is a list (whose len is interpreted as the batch size) -> return as np.array.
        elif isinstance(samples, list):
            return np.asarray(samples), False
        # Samples is already assumed to be batched. Return as is.
        return samples, False

    def get_shape(self, with_batch_rank=False, with_time_rank=False, time_major=None, **kwargs):
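        """
        Returns the shape of this Space, optionally with batch and/or time ranks prepended.

        Args:
            with_batch_rank (Union[bool,int]): Whether to prepend a batch rank (only if this Space has one).
                If True, None (tf) or 1 (pytorch) is used as the batch dimension; an int specifies a fixed size.
            with_time_rank (Union[bool,int]): Whether to prepend a time rank (only if this Space has one).
                If True, None is used as the time dimension; an int specifies a fixed size.
            time_major (Optional[bool]): Whether the time rank should come before the batch rank.
                Uses `self.time_major` if None.

        Returns:
            tuple: The shape tuple of this Space, possibly including batch- and/or time-rank dimensions.
        """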
        batch_rank = ()
        if with_batch_rank is not False:
            # None shapes are typically only allowed in static graphs.
            if get_backend() == "tf":
                batch_rank = (((None,) if with_batch_rank is True else (with_batch_rank,))
                              if self.has_batch_rank else ())
            elif get_backend() == "pytorch":
                batch_rank = (((1,) if with_batch_rank is True else (with_batch_rank,))
                              if self.has_batch_rank else ())

        time_rank = ()
        if with_time_rank is not False:
            time_rank = (((None,) if with_time_rank is True else (with_time_rank,))
                         if self.has_time_rank else ())

        time_major = self.time_major if time_major is None else time_major
        if time_major is False:
            return batch_rank + time_rank + self.shape
        else:
            return time_rank + batch_rank + self.shape

    @property
    def flat_dim(self):
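        """
        Returns:
            int: The number of elements in a single (flattened) sample of this Space (1 if shape == ()).
        """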
        return int(np.prod(self.shape))  # also works for shape=()

    @property
    def bounds(self):
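        """
        Returns:
            tuple: The (low, high) bounds of this Space as np.ndarrays.
        """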
        return self.low, self.high

    def tensor_backed_bounds(self):
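        """
        Returns:
            tuple: The (low, high) bounds as torch tensors (pytorch backend) or as np.ndarrays (otherwise).
        """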
        if get_backend() == "pytorch":
            return torch.tensor(self.low), torch.tensor(self.high)
        else:
            return self.low, self.high

    def get_variable(self, name, is_input_feed=False, add_batch_rank=None, add_time_rank=None,
                     time_major=None, is_python=False, local=False, **kwargs):
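        """
        Creates a variable, placeholder or plain python structure matching this Space's shape and dtype.

        Args:
            name (str): The name of the variable.
            is_input_feed (bool): Whether to create an input placeholder (tf) or a zero-filled tensor
                acting as one (pytorch) instead of a regular variable.
            add_batch_rank (Optional[Union[bool,int]]): Whether (or with which fixed size) to add a batch rank.
                Uses `self.has_batch_rank` if None.
            add_time_rank (Optional[Union[bool,int]]): Whether (or with which fixed size) to add a time rank.
                Uses `self.has_time_rank` if None.
            time_major (Optional[bool]): Whether the time rank should come before the batch rank.
                Uses `self.time_major` if None.
            is_python (bool): Whether to create a plain python list of zeros instead of a backend variable.
            local (bool): Whether the tf variable should go into the LOCAL_VARIABLES (instead of
                GLOBAL_VARIABLES) collection.

        Returns:
            any: The created variable, placeholder, tensor or python list.
        """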
        add_batch_rank = self.has_batch_rank if add_batch_rank is None else add_batch_rank
        if add_batch_rank is False:
            batch_rank = ()
        elif add_batch_rank is True:
            batch_rank = (None,) if get_backend() == "tf" else (1,)
        else:
            batch_rank = (add_batch_rank,)

        add_time_rank = self.has_time_rank if add_time_rank is None else add_time_rank
        if add_time_rank is False:
            time_rank = ()
        elif add_time_rank is True:
            time_rank = (None,) if get_backend() == "tf" else (1,)
        else:
            time_rank = (add_time_rank,)
        time_major = self.time_major if time_major is None else time_major

        if time_major is False:
            shape = batch_rank + time_rank + self.shape
        else:
            shape = time_rank + batch_rank + self.shape

        if is_python is True or get_backend() == "python":
            if isinstance(add_batch_rank, int):
                if isinstance(add_time_rank, int) and add_time_rank > 0:
                    if time_major:
                        var = [[0 for _ in range_(add_batch_rank)] for _ in range_(add_time_rank)]
                    else:
                        var = [[0 for _ in range_(add_time_rank)] for _ in range_(add_batch_rank)]
                else:
                    var = [0 for _ in range_(add_batch_rank)]
            elif isinstance(add_time_rank, int) and add_time_rank > 0:
                var = [0 for _ in range_(add_time_rank)]
            else:
                var = []

            # TODO: Un-indent and construct the PyTorch tensor directly?
            if get_backend() == "pytorch" and is_input_feed:
                # Convert to a PyTorch tensor acting as a faux placeholder.
                return torch.zeros(shape, dtype=convert_dtype(dtype=self.dtype, to="pytorch"))
            else:
                # TODO also convert?
                return var

        elif get_backend() == "tf":
            # TODO: re-evaluate the cutting of a leading '/_?' (tf doesn't like it)
            name = re.sub(r'^/_?', "", name)
            if is_input_feed:
                variable = tf.placeholder(dtype=convert_dtype(self.dtype), shape=shape, name=name)
                if self.has_batch_rank:
                    variable._batch_rank = self.has_batch_rank
                if self.has_time_rank:
                    variable._time_rank = self.has_time_rank
            else:
                init_spec = kwargs.pop("initializer", None)
                # Bools should be initializable via 0 (False) or non-0 (True).
                if self.dtype == np.bool_ and isinstance(init_spec, (int, float)):
                    init_spec = (init_spec != 0)

                if self.dtype == np.str_ and init_spec == 0:
                    initializer = None
                else:
                    initializer = Initializer.from_spec(shape=shape, specification=init_spec).initializer

                variable = tf.get_variable(
                    name, shape=shape, dtype=convert_dtype(self.dtype), initializer=initializer,
                    collections=[tf.GraphKeys.GLOBAL_VARIABLES if local is False else tf.GraphKeys.LOCAL_VARIABLES],
                    **kwargs
                )
            # Add batch/time rank flags to the op.
            if self.has_batch_rank:
                variable._batch_rank = 0 if self.time_major is False else 1
            if self.has_time_rank:
                variable._time_rank = 1 if self.time_major is False else 0
            return variable

    def zeros(self, size=None):
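        """
        Convenience method: returns a zero-filled sample of this Space (calls `self.sample` with
        `fill_value=0`).

        Args:
            size (Optional[int]): Same as in `sample()`.
        """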
        return self.sample(size=size, fill_value=0)

    def contains(self, sample):
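        """
        Checks whether the given sample matches this Space's shape and lies within its low/high bounds.

        Args:
            sample (any): The sample to check.

        Returns:
            bool: Whether `sample` is a valid member of this Space.
        """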
        sample_shape = sample.shape if not np.isscalar(sample) else ()
        if sample_shape != self.shape:
            return False
        return (sample >= self.low).all() and (sample <= self.high).all()

    def map(self, mapping):
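        """
        Applies `mapping` to this Space (non-container Spaces use the empty string "" as key) and
        returns the result.
        """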
        return mapping("", self)

    def __repr__(self):
        return "{}({} {} {}{})".format(
            type(self).__name__.title(), self.shape, str(self.dtype),
            "; +batch" if self.has_batch_rank else "",
            "; +time" if self.has_time_rank else ""
        )

    def __eq__(self, other):
        # NOTE: low/high bounds are currently not compared (check below is commented out):
        # and np.allclose(self.low, other.low) and np.allclose(self.high, other.high)
        return isinstance(other, self.__class__) and \
               self.shape == other.shape and self.dtype == other.dtype

    def __hash__(self):
        if self.shape == () or self.global_bounds is not False:
            return hash((np.asscalar(self.low), np.asscalar(self.high)))
        return hash((tuple(self.low), tuple(self.high)))