"""
Tile Coding Software version 3.0beta
by Rich Sutton
based on a program created by Steph Schaeffer and others.
External documentation and recommendations on the use of this code are
available in the reinforcement learning textbook by Sutton and Barto, and on
the web. These need to be understood before this code is.

This software is for Python 3 or more.

This is an implementation of grid-style tile codings, based originally on the
UNH CMAC code (see http://www.ece.unh.edu/robots/cmac.htm), but by now highly
changed. Here we provide a function, "tiles", that maps floating and integer
variables to a list of tiles, and a second function, "tileswrap", that does the
same while wrapping some floats to provided widths (the lower wrap value is
always 0).

The float variables will be gridded at unit intervals, so generalization will
be by approximately 1 in each direction, and any scaling will have to be done
externally before calling tiles.

Num-tilings should be a power of 2, e.g., 16. To make the offsetting work
properly, it should also be greater than or equal to four times the number of
floats.

The first argument is either an index hash table of a given size (created by
IHT(size)), an integer "size" (range of the indices from 0), or None (for
testing, indicating that the tile coordinates are to be returned without being
converted to indices).

NOTE: @ShangtongZhang made some minor changes to make it work in Python 2
source:
https://github.com/ShangtongZhang/reinforcement-learning-an-introduction
http://incompleteideas.net/sutton/tiles/tiles3.html
"""

from math import floor, log
from six.moves import zip_longest

# Module-level alias for the built-in hash(); used wherever coordinates are
# hashed directly (integer-sized tables and a full IHT).
basehash = hash


class IHT:
    '''Index hash table: assigns dense integer indices to hashable keys.

    Every previously unseen key receives the next unused index (0, 1, 2, ...)
    until ``size`` keys are stored. After that, unseen keys are no longer
    stored; their index is computed as ``basehash(key) % size`` and may
    therefore coincide with an index that is already in use.
    '''

    def __init__(self, sizeval):
        # key -> index assigned to that key
        self.dictionary = {}
        # number of lookups answered by hashing because the table was full
        self.overfullCount = 0
        # maximum number of keys that get a stored index
        self.size = sizeval

    def __str__(self):
        '''Return a one-line summary of capacity, overflow count and fill.'''
        pieces = (
            "Collision table:",
            " size:", str(self.size),
            " overfullCount:", str(self.overfullCount),
            " dictionary:", str(len(self.dictionary)), " items",
        )
        return "".join(pieces)

    def count(self):
        '''Return the number of keys that currently have a stored index.'''
        stored = len(self.dictionary)
        return stored

    def fullp(self):
        '''Return True once the number of stored keys has reached ``size``.'''
        stored = len(self.dictionary)
        return stored >= self.size

    def getindex(self, obj, readonly=False):
        '''Return the index for ``obj``, assigning a new one when allowed.

        :param obj: hashable key to look up
        :param readonly: when true, an unknown key is not added and None is
            returned instead
        :returns: the stored index for a known key; None for an unknown key
            in readonly mode; the next free index for an unknown key while
            there is room; otherwise ``basehash(obj) % size``
        '''
        table = self.dictionary
        try:
            return table[obj]
        except KeyError:
            if readonly:
                return None

        used = self.count()
        if used >= self.size:
            # Table is full: announce it once, then fall back to hashing,
            # which can collide with indices handed out earlier.
            if self.overfullCount == 0:
                print('IHT full, starting to allow collisions')
            self.overfullCount += 1
            return basehash(obj) % self.size

        # Room left: the current key count is the next unused index.
        table[obj] = used
        return used


def hashcoords(coordinates, m, readonly=False):
    '''Convert one tile's coordinates into the form selected by ``m``.

    :param coordinates: iterable of hashable coordinate values for one tile
    :param m: selects the conversion. None returns ``coordinates`` unchanged;
        an IHT returns ``m.getindex(tuple(coordinates), readonly)``; an int
        returns ``basehash(tuple(coordinates)) % m``
    :param readonly: forwarded to ``IHT.getindex``; ignored for other ``m``
    :returns: the coordinates, or the index computed as described above
    :raises TypeError: if ``m`` is not None, an IHT or an int
    '''
    if m is None:
        return coordinates
    if isinstance(m, IHT):
        return m.getindex(tuple(coordinates), readonly)
    if isinstance(m, int):
        return basehash(tuple(coordinates)) % m
    # Previously an unsupported ``m`` fell through and returned None, so the
    # callers silently produced lists of None instead of tile indices.
    raise TypeError(
        "hashcoords: m must be an IHT, an int, or None, got %s"
        % type(m).__name__)


def tiles(ihtORsize, numtilings, floats, ints=[], readonly=False):
    '''Return one tile index per tiling for the given floats and ints.

    :param ihtORsize: IHT, int or None; passed to hashcoords to turn each
        tile's coordinates into an index
    :param numtilings: integer number of tilings; also the resolution at
        which each float is quantized
    :param floats: iterable of real values making up the input vector
    :param ints: optional iterable of extra values appended to every tile's
        coordinates (the default list is only read, never modified)
    :param readonly: passed through to hashcoords
    :returns: list with ``numtilings`` entries, one hashcoords result per
        tiling
    '''
    quantized = [floor(value * numtilings) for value in floats]
    indices = []
    for tiling in range(numtilings):
        step = 2 * tiling
        # The k-th float is displaced by tiling + k * step quantized units,
        # so every tiling is offset differently along every dimension.
        tile = [tiling] + [
            (q + tiling + k * step) // numtilings
            for k, q in enumerate(quantized)
        ]
        tile.extend(ints)
        indices.append(hashcoords(tile, ihtORsize, readonly))
    return indices


def tileswrap(ihtORsize, numtilings, floats, wrapwidths, ints=[],
              readonly=False):
    '''
    Return one tile index per tiling for the given floats and ints, wrapping
    the float coordinates that have a wrap width.

    :param ihtORsize: IHT, int or None; passed to hashcoords to turn each
        tile's coordinates into an index
    :param numtilings: integer. the number of tilings desired; also the
        resolution at which each float is quantized
    :param floats: list. a list of real values making up the input vector
    :param wrapwidths: list. wrap width for each float, paired by position.
        A coordinate whose width is truthy is reduced modulo that width; a
        falsy width (None or 0), or a width missing because the list is
        shorter than floats, leaves the coordinate unwrapped
    :param ints*: list. optional list of extra values appended to every
        tile's coordinates (the default list is only read, never modified)
    :param readonly*: boolean. passed through to hashcoords
    :returns: list with ``numtilings`` entries, one hashcoords result per
        tiling
    '''
    quantized = [floor(value * numtilings) for value in floats]
    indices = []
    for tiling in range(numtilings):
        tile = [tiling]
        pairs = zip_longest(quantized, wrapwidths)
        for k, (q, width) in enumerate(pairs):
            # Displacement of the k-th float in this tiling: tiling, plus
            # 2 * tiling for every float before it.
            offset = tiling * (2 * k + 1)
            cell = (q + offset % numtilings) // numtilings
            if width:
                cell = cell % width
            tile.append(cell)
        tile.extend(ints)
        indices.append(hashcoords(tile, ihtORsize, readonly))
    return indices