import math
from xml.dom.minidom import Document,Node

class Distribution(dict):
    """
    A probability distribution over hashable objects.

    The underlying ``dict`` maps ``hash(element)`` to the element's
    probability, while ``self._domain`` maps the same hash back to the element
    itself.  Two distinct elements with equal hashes therefore share one entry.

    .. warning:: If you make the domain values mutable types, try not to change their values while they are inside the distribution.  If you must change a domain value, it is better to first delete the old value, change it, and then re-insert it.
    """
    # Tolerance used by normalize() to decide whether the total is already 1
    epsilon = 1e-8

    def __new__(cls,args=None,rationality=None):
        """
        Allocates the instance and its hash-to-element table.

        The table is created here rather than in ``__init__`` so that code
        paths that bypass ``__init__`` (e.g., ``__setstate__``) can still
        call ``__setitem__``.
        """
        obj = dict.__new__(cls)
        obj._domain = {}
        return obj

    def __init__(self,args=None,rationality=None):
        """
        :param args: the initial elements of the probability distribution (a dictionary, another distribution, or an XML ``distribution`` node)
        :type args: dict
        :param rationality: if not ``None``, then use as a rationality parameter in a quantal response over the provided values
        :type rationality: float
        """
        if isinstance(args,Node):
            # XML specification given
            self.parse(args)
        elif isinstance(args,Distribution):
            # Some other distribution given
            for key in args.domain():
                self[key] = args[key]
        elif isinstance(args,dict):
            if rationality is None:
                # Probability dictionary provided
                for key,value in args.items():
                    self[key] = value
            else:
                # Do quantal response / softmax on table of values
                for key,V in args.items():
                    self[key] = math.exp(rationality*V)
                # Exponentiated values are only weights until normalized
                self.normalize()

    def first(self):
        """
        :returns: the first element in this distribution's domain (most useful if there's only one element)
        :raises StopIteration: if the distribution is empty
        """
        return next(iter(self.domain()))

    def get(self,element):
        """
        :returns: the probability of the given element, or 0 if it is not in the domain
        :rtype: float
        """
        key = hash(element)
        return dict.get(self,key,0.)

    def __getitem__(self,element):
        """
        :returns: the probability of the given element
        :raises KeyError: if the element is not in the domain
        """
        key = hash(element)
        return dict.__getitem__(self,key)

    def __setitem__(self,element,value):
        """
        :param element: the domain element
        :param value: the probability to associate with the given key
        :type value: float
        """
        key = hash(element)
        # Keep the element table and the probability table in step
        self._domain[key] = element
        dict.__setitem__(self,key,value)

    def items(self):
        """
        :returns: a generator of (element, probability) pairs (elements, not their hashes)
        """
        for key,value in dict.items(self):
            yield self._domain[key],value

    def addProb(self,element,value):
        """
        Utility method that increases the probability of the given element by the given value
        (the element is inserted with that probability if it is not yet in the domain)
        """
        key = hash(element)
        if key in self._domain:
            dict.__setitem__(self,key,dict.__getitem__(self,key)+value)
        else:
            self._domain[key] = element
            dict.__setitem__(self,key,value)

    def getProb(self,element):
        """
        Utility method that is almost identical to __getitem__, except that it returns 0 for missing elements, instead of throwing a ``KeyError``
        """
        try:
            return self[element]
        except KeyError:
            return 0.

    def __delitem__(self,element):
        """
        Removes the given element from the distribution

        :raises KeyError: if the element is not in the domain
        """
        key = hash(element)
        dict.__delitem__(self,key)
        del self._domain[key]

    def clear(self):
        """
        Removes every element from the distribution
        """
        dict.clear(self)
        self._domain.clear()

    def replace(self,old,new):
        """
        Replaces one element in the sample space with another, carrying over its probability.

        :raises KeyError: if the original element does not exist

        .. note:: No merge is performed: if the new element is already present, its previous probability is overwritten.
        """
        prob = self[old]
        del self[old]
        self[new] = prob

    def domain(self):
        """
        :returns: the sample space of this probability distribution
        :rtype: list
        """
        return list(self._domain.values())

    def normalize(self):
        """
        Normalizes the distribution so that the sum of values = 1.
        If the values sum to zero, a uniform distribution is assigned instead.
        """
        total = sum(self.values())
        if abs(total-1.) > self.epsilon:
            for key in self.domain():
                try:
                    self[key] /= total
                except ZeroDivisionError:
                    self[key] = 1./float(len(self))

    def expectation(self):
        """
        :returns: the expected value of this distribution (``None`` if it is empty)
        :rtype: float
        """
        if len(self) == 1:
            # Shortcut if no uncertainty
            return self.domain()[0]
        else:
            # Start from the first product rather than 0 so that the
            # elements only need to support * and +=
            total = None
            for element in self.domain():
                if total is None:
                    total = element*self[element]
                else:
                    total += element*self[element]
            return total

    def __float__(self):
        return self.expectation()

    def sample(self,quantify=False):
        """
        :param quantify: if ``True``, also returns the amount of mass by which the sampling crossed the threshold of the generated sample's range
        :returns: an element from this domain, with a sample probability given by this distribution
        """
        import random
        selection = random.uniform(0.,sum(self.values()))
        for element in self.domain():
            if selection > self[element]:
                # Target lies beyond this element's share of the mass
                selection -= self[element]
            else:
                if quantify:
                    return element,selection
                else:
                    return element
        # We shouldn't get here. But in case of some floating-point weirdness?
        return element

    def set(self,element):
        """
        Reduce distribution to be 100% for the given element

        :param element: the element that will be the only one with nonzero probability
        """
        self.clear()
        self[element] = 1.

    def select(self,maximize=False):
        """
        Reduce distribution to a single element, sampled according to the given distribution

        :param maximize: if ``True``, pick the most probable element instead of sampling
        :returns: the probability of the selection made
        """
        if maximize:
            element = self.max()
        else:
            element = self.sample()
        prob = self[element]
        self.set(element)
        return prob

    def max(self):
        """
        :returns: the most probable element in this distribution (ties are broken by the larger hash key, which for small integers is the higher-valued element)
        :raises ValueError: if the distribution is empty
        """
        return self._domain[max([(dict.__getitem__(self,element),element) for element in self._domain])[1]]

    def entropy(self):
        """
        :returns: entropy (in bits) of this distribution
        :rtype: float
        """
        # Zero-probability entries contribute nothing (0*log 0 := 0);
        # skipping them also avoids a math domain error in log2
        return sum([-p*math.log2(p) for p in dict.values(self) if p > 0])

    def __add__(self,other):
        """
        :returns: the distribution of the sum, either of two independent distributions or of this distribution and a constant
        """
        if isinstance(other,Distribution):
            result = self.__class__()
            for me in self.domain():
                for you in other.domain():
                    result.addProb(me+you,self[me]*other[you])
            return result
        else:
            result = self.__class__()
            for element in self.domain():
                result.addProb(element+other,self[element])
            return result

    def __sub__(self,other):
        return self + (-other)

    def __neg__(self):
        """
        :returns: a distribution over the negated elements, with the same probabilities
        """
        result = self.__class__()
        for element in self.domain():
            result.addProb(-element,self[element])
        return result

    def __mul__(self,other):
        """
        :returns: a distribution over each element multiplied by the given constant
        :raises NotImplementedError: if asked to multiply by another distribution
        """
        if isinstance(other,Distribution):
            raise NotImplementedError('Unable to multiply %s by %s.' \
                                      % (self.__class__.__name__,other.__class__.__name__))
        else:
            result = self.__class__()
            for element in self.domain():
                result.addProb(element*other,self[element])
            return result

    def prune(self,epsilon=1e-8):
        """
        Merges elements that lie within ``epsilon`` of each other, adding the probability of the later element to the earlier one

        :param epsilon: maximum absolute difference for two elements to be merged
        :type epsilon: float
        """
        elements = self.domain()
        i = 0
        while i < len(self)-1:
            el1 = elements[i]
            j = i+1
            while j < len(self):
                el2 = elements[j]
                if abs(el1-el2) < epsilon:
                    self[el1] += self[el2]
                    del self[el2]
                    # Drop it from the local list too, so indices keep
                    # matching len(self); j now points at the next element
                    del elements[j]
                else:
                    j += 1
            i += 1

    def __xml__(self):
        """
        :returns: An XML Document object representing this distribution
        """
        doc = Document()
        root = doc.createElement('distribution')
        doc.appendChild(root)
        for key,value in self._domain.items():
            prob = dict.__getitem__(self,key)
            node = doc.createElement('entry')
            root.appendChild(node)
            node.setAttribute('probability',str(prob))
            if isinstance(value,str):
                # parse() reads the 'key' attribute back as the element itself
                node.setAttribute('key',value)
            else:
                node.appendChild(self.element2xml(value))
        return doc

    def element2xml(self,value):
        """
        Hook for subclasses: generates the XML node for a non-string domain element

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError('Unable to generate XML for distributions over %s' % (value.__class__.__name__))

    def parse(self,element):
        """Extracts the distribution from the given XML element

        :param element: The XML Element object specifying the distribution
        :type element: Element
        :returns: This :class:`Distribution` object"""
        assert element.tagName == 'distribution','Unexpected tag %s for %s' \
            % (element.tagName,self.__class__.__name__)
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                prob = float(node.getAttribute('probability'))
                value = str(node.getAttribute('key'))
                if not value:
                    # No string key, so the element is stored as a child node
                    subNode = node.firstChild
                    while subNode and subNode.nodeType != subNode.ELEMENT_NODE:
                        subNode = subNode.nextSibling
                    value = self.xml2element(None,subNode)
                self[value] = prob
            node = node.nextSibling
        return self

    def xml2element(self,key,node):
        """
        Hook for subclasses: converts an XML node back into a domain element (this base class simply returns the given key)
        """
        return key

    def sortedString(self):
        """
        :returns: one line per element (``probability%<TAB>element``), ordered by the elements' string form
        :rtype: str
        """
        elements = self.domain()
        elements.sort(key=str)
        return '\n'.join(['%4.1f%%\t%s' % (100.*self[el],str(el)) for el in elements])

    def __str__(self):
        return '\n'.join(['%d%%\t%s' % (100*self[el],str(el).replace('\n','\n\t'))
                          for el in self._domain.values()])

    def __hash__(self):
        return hash(str(self))

    def __copy__(self):
        # Copies by round-tripping through the XML representation
        return self.__class__(self.__xml__().documentElement)

    def __getstate__(self):
        return {el: self[el] for el in self.domain()}

    def __setstate__(self,state):
        for el,prob in state.items():
            self[el] = prob