import math
from xml.dom.minidom import Document, Node


class Distribution(dict):
    """
    A probability distribution over hashable objects

    The underlying ``dict`` maps ``hash(element)`` to the element's
    probability, while ``self._domain`` maps the same hash back to the
    element itself.  Consequently two distinct elements that happen to have
    the same hash share a single entry.  Only the methods overridden below
    translate between elements and hashes; inherited ``dict`` methods that
    are not overridden (e.g., iteration, ``keys()``) operate on the raw
    hash keys.

    .. warning:: If you make the domain values mutable types, try not to
       change their values while they are inside the distribution.  If you
       must change a domain value, it is better to first delete the old
       value, change it, and then re-insert it.
    """
    # Tolerance used by normalize() when deciding whether the total mass
    # already equals 1
    epsilon = 1e-8

    def __new__(cls, args=None, rationality=None):
        """
        Allocates the instance and its hash-to-element table.  The table is
        created here (rather than in ``__init__``) so that it exists even if
        ``__init__`` has not run, e.g., when ``__setstate__`` calls
        :meth:`clear`.
        """
        obj = dict.__new__(cls)
        obj._domain = {}
        return obj

    def __init__(self, args=None, rationality=None):
        """
        :param args: the initial elements of the probability distribution
           (a dictionary of element-to-probability, another
           :class:`Distribution` to copy, or an XML ``Node`` to parse)
        :type args: dict
        :param rationality: if not ``None``, then use as a rationality
           parameter in a quantal response over the provided values
           (only applies when ``args`` is a plain dictionary)
        :type rationality: float
        """
        dict.__init__(self)
        if isinstance(args, Node):
            self.parse(args)
        elif isinstance(args, Distribution):
            # Some other distribution given.  This test must precede the
            # plain-dict test, because a Distribution is also a dict.
            for key in args.domain():
                self[key] = args[key]
        elif isinstance(args, dict):
            if rationality is None:
                # Probability dictionary provided
                for key, value in args.items():
                    self[key] = value
            else:
                # Do quantal response / softmax on table of values
                for key, V in args.items():
                    self[key] = math.exp(rationality*V)
                self.normalize()

    def first(self):
        """
        :returns: the first element in this distribution's domain (most
           useful if there's only one element)
        :raises StopIteration: if the distribution is empty
        """
        return next(iter(self.domain()))

    def get(self, element, default=0.):
        """
        :param element: the domain element to look up
        :param default: the value to return if the element is absent
           (defaults to zero probability)
        :returns: the probability of the given element, or ``default`` if
           it is not in the domain
        """
        key = hash(element)
        return dict.get(self, key, default)

    def __getitem__(self, element):
        """
        :returns: the probability of the given element
        :raises KeyError: if the element is not in the domain
        """
        key = hash(element)
        return dict.__getitem__(self, key)

    def __setitem__(self, element, value):
        """
        :param element: the domain element
        :param value: the probability to associate with the given key
        :type value: float
        """
        key = hash(element)
        self._domain[key] = element
        dict.__setitem__(self, key, value)

    def items(self):
        """
        Generator over ``(element, probability)`` pairs (elements, not
        their hashes)
        """
        for key, value in dict.items(self):
            yield self._domain[key], value

    def addProb(self, element, value):
        """
        Utility method that increases the probability of the given element
        by the given value (inserting the element if it is new)
        """
        key = hash(element)
        if key in self._domain:
            dict.__setitem__(self, key, dict.__getitem__(self, key)+value)
        else:
            self._domain[key] = element
            dict.__setitem__(self, key, value)

    def getProb(self, element):
        """
        Utility method that is almost identical to __getitem__, except that
        it returns 0 for missing elements, instead of throwing a
        C{KeyError}
        """
        try:
            return self[element]
        except KeyError:
            return 0.

    def __delitem__(self, element):
        """
        Removes the element from the distribution

        :raises KeyError: if the element is not in the domain
        """
        key = hash(element)
        dict.__delitem__(self, key)
        del self._domain[key]

    def clear(self):
        """
        Removes all elements (both probabilities and the domain table)
        """
        dict.clear(self)
        self._domain.clear()

    def replace(self, old, new):
        """
        Replaces one element in the sample space with another, which takes
        over the original element's probability.  Raises a ``KeyError`` if
        the original element does not exist.  Does not merge: if the new
        element already exists, its previous probability is overwritten.
        """
        prob = self[old]
        del self[old]
        self[new] = prob

    def domain(self):
        """
        :returns: the sample space of this probability distribution
        :rtype: list
        """
        return list(self._domain.values())

    def normalize(self):
        """
        Normalizes the distribution so that the sum of values = 1.  Nothing
        is changed if the total is already within ``epsilon`` of 1.  If the
        division by the total raises ``ZeroDivisionError``, the element is
        assigned the uniform probability instead.
        """
        total = sum(self.values())
        if abs(total-1.) > self.epsilon:
            for key in self.domain():
                try:
                    self[key] /= total
                except ZeroDivisionError:
                    self[key] = 1./float(len(self))

    def expectation(self):
        """
        :returns: the expected value of this distribution (``None`` if the
           distribution is empty).  Elements must support multiplication by
           a probability and addition with each other.
        :rtype: float
        """
        if len(self) == 1:
            # Shortcut if no uncertainty: return the lone element as is
            return self.domain()[0]
        else:
            total = None
            for element in self.domain():
                if total is None:
                    total = element*self[element]
                else:
                    total += element*self[element]
            return total

    def __float__(self):
        # The single-element shortcut in expectation() returns the element
        # itself, which may be, e.g., an int; __float__ must return a float.
        return float(self.expectation())

    def sample(self, quantify=False):
        """
        :param quantify: if ``True``, also returns the amount of mass by
           which the sampling crossed the threshold of the generated
           sample's range
        :returns: an element from this domain, with a sample probability
           given by this distribution (a tuple of element and mass if
           ``quantify`` is ``True``)
        :raises ValueError: if the distribution is empty
        """
        import random
        if len(self) == 0:
            raise ValueError('Unable to sample from an empty distribution')
        selection = random.uniform(0., sum(self.values()))
        for element in self.domain():
            mass = self[element]
            if selection <= mass:
                if quantify:
                    return element, selection
                else:
                    return element
            selection -= mass
        # We shouldn't get here. But in case of some floating-point
        # weirdness, fall back to the last element, capping the reported
        # mass at that element's own probability.
        if quantify:
            return element, mass
        else:
            return element

    def set(self, element):
        """
        Reduce distribution to be 100% for the given element

        :param element: the element that will be the only one with nonzero
           probability
        """
        self.clear()
        self[element] = 1.

    def select(self, maximize=False):
        """
        Reduce distribution to a single element, sampled according to the
        given distribution

        :param maximize: if ``True``, pick the most probable element
           instead of sampling
        :returns: the probability of the selection made
        """
        if maximize:
            element = self.max()
        else:
            element = self.sample()
        prob = self[element]
        self.set(element)
        return prob

    def max(self):
        """
        :returns: the most probable element in this distribution (ties are
           broken in favor of the element with the larger hash, which
           coincides with the element's value for, e.g., small
           non-negative integers)
        :raises ValueError: if the distribution is empty
        """
        best = max((dict.__getitem__(self, key), key) for key in self._domain)
        return self._domain[best[1]]

    def entropy(self):
        """
        :returns: entropy (in bits) of this distribution
        """
        # Zero-probability entries contribute nothing (0*log(0) is taken to
        # be 0) and would otherwise make math.log2 raise a ValueError
        return sum([-p*math.log2(p) for p in dict.values(self) if p != 0])

    def __add__(self, other):
        """
        :returns: the distribution of the sum.  If ``other`` is a
           :class:`Distribution`, every pair of elements is combined with
           the product of their probabilities; otherwise ``other`` is added
           to every element.
        """
        result = self.__class__()
        if isinstance(other, Distribution):
            for me in self.domain():
                for you in other.domain():
                    result.addProb(me+you, self[me]*other[you])
        else:
            for element in self.domain():
                result.addProb(element+other, self[element])
        return result

    def __sub__(self, other):
        return self + (-other)

    def __neg__(self):
        """
        :returns: a new distribution over the negated elements
        """
        result = self.__class__()
        for element in self.domain():
            result.addProb(-element, self[element])
        return result

    def __mul__(self, other):
        """
        :returns: a new distribution with every element multiplied by
           ``other``
        :raises NotImplementedError: if ``other`` is a
           :class:`Distribution`
        """
        if isinstance(other, Distribution):
            raise NotImplementedError('Unable to multiply %s by %s.'
                                      % (self.__class__.__name__,
                                         other.__class__.__name__))
        else:
            result = self.__class__()
            for element in self.domain():
                result.addProb(element*other, self[element])
            return result

    def prune(self, epsilon=1e-8):
        """
        Merges elements that are within ``epsilon`` of each other (in
        absolute difference) into the one that comes first in the domain,
        which receives the combined probability

        :param epsilon: the threshold below which two elements are merged
        :type epsilon: float
        """
        elements = self.domain()
        i = 0
        while i < len(self)-1:
            el1 = elements[i]
            j = i+1
            while j < len(self):
                el2 = elements[j]
                if abs(el1-el2) < epsilon:
                    self[el1] += self[el2]
                    # Keep the local list in step with the distribution;
                    # j is not advanced, as the next element slides into it
                    del self[el2]
                    del elements[j]
                else:
                    j += 1
            i += 1

    def __xml__(self):
        """
        :returns: An XML Document object representing this distribution
        """
        doc = Document()
        root = doc.createElement('distribution')
        doc.appendChild(root)
        for key, value in self._domain.items():
            prob = dict.__getitem__(self, key)
            node = doc.createElement('entry')
            root.appendChild(node)
            node.setAttribute('probability', str(prob))
            if isinstance(value, str):
                # Store the element itself (not its hash), because parse()
                # reads this attribute back as the element
                node.setAttribute('key', value)
            else:
                node.appendChild(self.element2xml(value))
        return doc

    def element2xml(self, value):
        """
        Hook for generating the XML node of a non-string element; this
        base implementation always raises ``NotImplementedError``
        """
        raise NotImplementedError('Unable to generate XML for distributions over %s' % (value.__class__.__name__))

    def parse(self, element):
        """Extracts the distribution from the given XML element

        :param element: The XML Element object specifying the distribution
        :type element: Element
        """
        assert element.tagName == 'distribution', 'Unexpected tag %s for %s' \
            % (element.tagName, self.__class__.__name__)
        self.clear()
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                prob = float(node.getAttribute('probability'))
                value = str(node.getAttribute('key'))
                if not value:
                    # No string key, so the element is stored in the first
                    # child element node
                    subNode = node.firstChild
                    while subNode and subNode.nodeType != subNode.ELEMENT_NODE:
                        subNode = subNode.nextSibling
                    value = self.xml2element(None, subNode)
                self[value] = prob
            node = node.nextSibling

    def xml2element(self, key, node):
        """
        Hook for extracting an element from its XML node; this base
        implementation ignores the node and returns ``key`` unchanged
        """
        return key

    def sortedString(self):
        """
        :returns: one line per element (percentage with one decimal place,
           a tab, then the element), ordered by the elements' string
           representations
        :rtype: str
        """
        elements = self.domain()
        elements.sort(key=str)
        return '\n'.join(['%4.1f%%\t%s' % (100.*self[el], str(el)) for el in elements])

    def __str__(self):
        return '\n'.join(['%d%%\t%s' % (100*self[el], str(el).replace('\n', '\n\t'))
                          for el in self._domain.values()])

    def __hash__(self):
        return hash(str(self))

    def __copy__(self):
        # Copies by round-tripping through XML, so non-string elements work
        # only if element2xml / xml2element are overridden to handle them
        return self.__class__(self.__xml__().documentElement)

    def __getstate__(self):
        return {el: self[el] for el in self.domain()}

    def __setstate__(self, state):
        self.clear()
        for el, prob in state.items():
            self[el] = prob