```
import math
from xml.dom.minidom import Document,Node

class Distribution(dict):
    """
    A probability distribution over hashable objects.

    Probabilities are stored in the underlying ``dict`` keyed by
    ``hash(element)``; the ``_domain`` table maps each hash key back to the
    element itself.  Two distinct elements with the same hash (in CPython,
    for example, ``-1`` and ``-2``) therefore share a single entry.
    Inherited ``dict`` methods that are not overridden here (e.g. ``keys``)
    still operate on the hash keys rather than on the elements.

    .. warning:: If you make the domain values mutable types, try not to change their values while they are inside the distribution.  If you must change a domain value, it is better to first delete the old value, change it, and then re-insert it.
    """
    # Tolerance used by normalize() when deciding whether the total is already 1
    epsilon = 1e-8

    def __new__(cls, args=None, rationality=None):
        """
        Create the instance and its hash -> element table.

        The table is created here (rather than in ``__init__``) so that it
        exists even when an instance is built without ``__init__`` being run.
        """
        obj = dict.__new__(cls)
        obj._domain = {}
        return obj

    def __init__(self, args=None, rationality=None):
        """
        :param args: the initial elements of the probability distribution (an XML ``distribution`` node, another distribution, or a table mapping elements to values)
        :type args: dict
        :param rationality: if not ``None``, then use as a rationality parameter in a quantal response over the provided values
        :type rationality: float
        """
        dict.__init__(self)
        if isinstance(args, Node):
            self.parse(args)
        elif isinstance(args, Distribution):
            # Some other distribution given
            for element in args.domain():
                self[element] = args[element]
        elif isinstance(args, dict):
            if rationality is None:
                # Probability dictionary provided
                for element, prob in args.items():
                    self[element] = prob
            elif args:
                # Do quantal response / softmax on table of values.
                # Shifting every exponent by the largest one leaves the
                # normalized result unchanged but keeps math.exp from
                # overflowing on large rationality*V products.
                scaled = [(element, rationality * V) for element, V in args.items()]
                peak = max(exponent for _, exponent in scaled)
                for element, exponent in scaled:
                    self[element] = math.exp(exponent - peak)
                self.normalize()

    def first(self):
        """
        :returns: the first element in this distribution's domain (most useful if there's only one element)
        """
        return next(iter(self.domain()))

    def get(self, element, default=0.):
        """
        :param element: the domain element to look up
        :param default: the value returned when the element is absent
        :returns: the probability of the element, or ``default`` (0 unless specified) if it is not in the domain
        """
        return dict.get(self, hash(element), default)

    def __getitem__(self, element):
        """
        :returns: the probability of the given element
        :raises KeyError: if the element is not in the domain
        """
        return dict.__getitem__(self, hash(element))

    def __setitem__(self, element, value):
        """
        :param element: the domain element
        :param value: the probability to associate with the given key
        :type value: float
        """
        key = hash(element)
        self._domain[key] = element
        dict.__setitem__(self, key, value)

    def items(self):
        """
        :returns: a generator over (element, probability) pairs
        """
        for key, value in dict.items(self):
            yield self._domain[key], value

    # NOTE(review): the ``def`` line of this method was missing from the file.
    # The name and signature are reconstructed from the surviving docstring
    # and body -- confirm against callers.
    def addProb(self, element, value):
        """
        Utility method that increases the probability of the given element by the given value

        :param element: the domain element (added to the domain if absent)
        :param value: the amount of probability to add
        """
        key = hash(element)
        if key in self._domain:
            dict.__setitem__(self, key, dict.__getitem__(self, key) + value)
        else:
            self._domain[key] = element
            dict.__setitem__(self, key, value)

    def getProb(self, element):
        """
        Utility method that is almost identical to __getitem__, except that it returns 0 for missing elements, instead of throwing a ``KeyError``
        """
        try:
            return self[element]
        except KeyError:
            return 0.

    def __delitem__(self, element):
        """
        Remove the element and its probability.

        :raises KeyError: if the element is not in the domain
        """
        key = hash(element)
        dict.__delitem__(self, key)
        del self._domain[key]

    def clear(self):
        """Remove every element from the distribution"""
        dict.clear(self)
        self._domain.clear()

    def replace(self, old, new):
        """Replaces one element in the sample space with another, carrying over its probability.

        If the new element is already present, its probability is overwritten (no merge is performed).

        :raises KeyError: if the original element does not exist
        """
        prob = self[old]
        del self[old]
        self[new] = prob

    def domain(self):
        """
        :returns: the sample space of this probability distribution
        :rtype: list
        """
        return list(self._domain.values())

    def normalize(self):
        """Normalizes the distribution so that the sum of values = 1

        Nothing is changed when the total is already within ``epsilon`` of 1.
        If the division raises ``ZeroDivisionError`` the element is given the
        uniform probability instead.
        """
        total = sum(self.values())
        if abs(total - 1.) > self.epsilon:
            for element in self.domain():
                try:
                    self[element] /= total
                except ZeroDivisionError:
                    self[element] = 1. / float(len(self))

    def expectation(self):
        """
        :returns: the expected value of this distribution (``None`` if it is empty)
        :rtype: float
        """
        if len(self) == 1:
            # Shortcut if no uncertainty
            return self.domain()[0]
        # Start from None rather than 0 so that elements only need to
        # support scaling by a probability and in-place addition.
        total = None
        for element in self.domain():
            weighted = element * self[element]
            if total is None:
                total = weighted
            else:
                total += weighted
        return total

    def __float__(self):
        # __float__ is required to return an actual float, whereas the
        # expectation may be an int (e.g. a single integer element).
        return float(self.expectation())

    def sample(self, quantify=False):
        """
        :param quantify: if ``True``, also returns the amount of mass by which the sampling crosssed the threshold of the generated sample's range
        :returns: an element from this domain, with a sample probability given by this distribution
        """
        import random
        selection = random.uniform(0., sum(self.values()))
        for element in self.domain():
            prob = self[element]
            if selection > prob:
                # Not within this element's share of the mass; move on
                selection -= prob
            elif quantify:
                return element, selection
            else:
                return element
        # We shouldn't get here. But in case of some floating-point weirdness?
        return element

    def set(self, element):
        """
        Reduce distribution to be 100% for the given element

        :param element: the element that will be the only one with nonzero probability
        """
        self.clear()
        self[element] = 1.

    def select(self, maximize=False):
        """
        Reduce distribution to a single element, sampled according to the given distribution

        :param maximize: if ``True``, pick the most probable element instead of sampling
        :returns: the probability of the selection made
        """
        element = self.max() if maximize else self.sample()
        prob = self[element]
        self.set(element)
        return prob

    def max(self):
        """
        :returns: the most probable element in this distribution (breaking ties by returning the element with the largest hash key)
        """
        best = max(self._domain, key=lambda key: (dict.__getitem__(self, key), key))
        return self._domain[best]

    def entropy(self):
        """
        :returns: entropy (in bits) of this distribution
        """
        # Zero-probability elements contribute nothing and would otherwise
        # make math.log2 raise.
        return sum(-p * math.log2(p) for p in dict.values(self) if p != 0)

    # NOTE(review): the ``def`` line of __add__ and the loop bodies of
    # __add__, __neg__ and __mul__ were missing from the file.  They are
    # reconstructed here as "apply the operation to each element and
    # accumulate its probability" -- confirm against the intended behavior.
    def __add__(self, other):
        """
        :param other: another distribution, or a value to add to every element
        :returns: a new distribution over the sums; for two distributions, each pair of elements contributes the product of its probabilities
        """
        result = self.__class__()
        if isinstance(other, Distribution):
            for me in self.domain():
                for you in other.domain():
                    result.addProb(me + you, self[me] * other[you])
        else:
            for element in self.domain():
                result.addProb(element + other, self[element])
        return result

    def __sub__(self, other):
        return self + (-other)

    def __neg__(self):
        """
        :returns: a new distribution over the negated elements
        """
        result = self.__class__()
        for element in self.domain():
            result.addProb(-element, self[element])
        return result

    def __mul__(self, other):
        """
        :param other: a value by which to scale every element
        :returns: a new distribution over the scaled elements
        :raises NotImplementedError: if ``other`` is itself a distribution
        """
        if isinstance(other, Distribution):
            raise NotImplementedError('Unable to multiply %s by %s.' \
                                      % (self.__class__.__name__, other.__class__.__name__))
        result = self.__class__()
        for element in self.domain():
            result.addProb(element * other, self[element])
        return result

    def prune(self, epsilon=1e-8):
        """
        Merge elements that lie within ``epsilon`` of each other, adding the
        probability of the later element to the earlier one.

        :param epsilon: the maximum absolute difference between elements to be merged
        :type epsilon: float
        """
        elements = self.domain()
        i = 0
        while i < len(elements) - 1:
            el1 = elements[i]
            j = i + 1
            while j < len(elements):
                el2 = elements[j]
                if abs(el1 - el2) < epsilon:
                    self[el1] += self[el2]
                    del self[el2]
                    # The next candidate slides into position j
                    del elements[j]
                else:
                    j += 1
            i += 1

    def __xml__(self):
        """
        :returns: An XML Document object representing this distribution
        """
        doc = Document()
        root = doc.createElement('distribution')
        doc.appendChild(root)
        for key, element in self._domain.items():
            node = doc.createElement('entry')
            root.appendChild(node)
            node.setAttribute('probability', str(dict.__getitem__(self, key)))
            if isinstance(element, str):
                # Store the string itself (not its hash), because parse()
                # uses the 'key' attribute as the element value.
                node.setAttribute('key', element)
            else:
                node.appendChild(self.element2xml(element))
        return doc

    def element2xml(self, value):
        """
        Hook for generating the XML of a non-string element; this base implementation always raises.

        :raises NotImplementedError: always
        """
        raise NotImplementedError('Unable to generate XML for distributions over %s' % (value.__class__.__name__))

    def parse(self, element):
        """Extracts the distribution from the given XML element

        :param element: The XML Element object specifying the distribution
        :type element: Element
        :returns: This :class:`Distribution` object"""
        assert element.tagName == 'distribution', 'Unexpected tag %s for %s' \
            % (element.tagName, self.__class__.__name__)
        self.clear()
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                prob = float(node.getAttribute('probability'))
                value = str(node.getAttribute('key'))
                if not value:
                    # No string key: the element is the first child element
                    subNode = node.firstChild
                    while subNode and subNode.nodeType != subNode.ELEMENT_NODE:
                        subNode = subNode.nextSibling
                    value = self.xml2element(None, subNode)
                self[value] = prob
            node = node.nextSibling
        return self

    def xml2element(self, key, node):
        """
        Hook for rebuilding an element from its XML node; this base implementation ignores the node and returns ``key``.
        """
        return key

    def sortedString(self):
        """
        :returns: one line per element (percentage with one decimal, tab, element), ordered by the elements' string form
        :rtype: str
        """
        elements = self.domain()
        elements.sort(key=str)
        return '\n'.join(['%4.1f%%\t%s' % (100. * self[el], str(el)) for el in elements])

    def __str__(self):
        return '\n'.join(['%d%%\t%s' % (100 * self[el], str(el).replace('\n', '\n\t'))
                          for el in self._domain.values()])

    def __hash__(self):
        return hash(str(self))

    def __copy__(self):
        # Copies by round-tripping through the XML representation
        return self.__class__(self.__xml__().documentElement)

    def __getstate__(self):
        return {el: self[el] for el in self.domain()}

    def __setstate__(self, state):
        self.clear()
        for el, prob in state.items():
            self[el] = prob
```