""" A basic implementation of the "visitor pattern" for Python, using decorators. Inspiration from http://chris-lamb.co.uk/2006/12/08/visitor-pattern-in-python/ which contains a description of such a system but no implementation code that I could find And http://www.artima.com/weblogs/viewpost.jsp?thread=101605 in which Guido van Rossum shows how dynamic dispatch ("multi-methods") can be done. Please don't use this module for evil messes :), only use it when a visitor pattern will actually eliminate lots of nasty boilerplate code. Simple example: @when(str) def myFunc(arg): print "My string has length %d" % (len(arg)) @when(int) def myFunc(arg): print "My integer is %d" % arg myFunc("XYZ") My string has length 3 myFunc(12) My integer is 12 Advanced/Silly features: Optionally, multiple overloads can be called in sequence when visiting a class hierarchy. So, if you have annotated visitor methods for both a superclass and a subclass, the visitor pattern will optionally call first the superclass and then the subclass. The result of such "cascaded" superclass calls is discarded. Example: class Superclass: def __init__(self): pass class SubA(Superclass): def __init__(self): pass class SubB(Superclass): def __init__(self): pass @is_visitor class MyVisitor: @when(Superclass, allow_cascaded_calls=True) def foo(self, s): print "I am a superclass call" return 1 @when(SubA) def foo(self, a): print "I am a subclass A call" return 2 v=MyVisitor() a=SubA() b=SubB() Output: v.foo(a) I am a superclass call I am a subclass call 2 v.foo(b) I am a superclass call 1 Copyright 2011 Angus Gratton. Licensed under New BSD License as described in the file LICENSE. """ import inspect, types # these are the "current" overloaded methods, the @is_visitor annotation # will blat them out again after the current class is defined (allowing multiple # classes to have mixed overloads with the same method name!) # # dict key is function name, value is a method_overload with all the functions named that # in the class _methods = {} class when(object): """ Annotation indicating a method is a dynamic dispatch overload. Argument is the type of the first function argument,which will be used for dynamic dispatch. Arguments: argtype - this method should be invoked only with a first argument that matches (or subclasses) this type allow_cascaded_calls - this method will also be invoked if (default True) """ def __init__(self, argtype, allow_cascaded_calls=False): self.argtype = argtype self.allow_cascade=allow_cascaded_calls def __call__(self, func): #print "assigning %s to func %s" % (self.argtype, func) self.func_name =func.__name__ if not self.func_name in _methods: _methods[self.func_name] = method_overload(self.func_name) _methods[self.func_name].register(self.argtype, func, self.allow_cascade) return _methods[self.func_name] class method_overload(object): """ This is the actual overload information for a particular method name on a particular class (or no class.) These are internal, created at module import time. """ def __init__(self, func_name): self.func_name = func_name self.registry = {} self.allow_cascade = {} def register(self, argtype, func, allow_cascade): self.registry[argtype] = func self.allow_cascade[argtype] = allow_cascade def __get__(self, obj, type=None): """ This __get__ is called when the method is bound on a class, and returns a bound_caller which knows the instance and the class type. If the method is not bound on a class, this is skipped and __call__ is called directly. """ return bound_caller(self, obj, type) def __call__(self, *args, **kw): """ This __call__ is only reached when the method is not bound to a class """ return self.call_internal(lambda f:f, args, kw) def call_internal(self, func_modifier, args, kw): """ Common utility class for calling an overloaded method, either bound on a class or not. func_modifier is a lambda function which is used to "bind" bound methods to the correct instance """ argtype = type(args[0]) class Old: pass if argtype is types.InstanceType: # old-style class argtype = args[0].__class__ hier = list(inspect.getmro(argtype)) # class hierarchy hier.reverse() # order w/ superclass first hier = [ t for t in hier if t in self.registry ] if len(hier) == 0: raise TypeError("Function %s has no compatible overloads registered for argument type %s" % (self.func_name, argtype)) result = None for t in hier: if not self.allow_cascade[t] and t != hier[-1]: continue # don't "cascade" down from superclass on this method result = func_modifier(self.registry[t])(*args, **kw) return result class bound_caller(object): """ Temporary class instantiated once per method call(!!) for methods bound to a class, contains the bound instance and its class type. (Implementation note: It may be possible to replace this with a nested function or a lambda function, not sure if this would change overhead?) """ def __init__(self, overload, obj, cls): self.overload = overload self.obj = obj self.cls = cls def __call__(self, *args, **kw): return self.overload.call_internal(lambda l:l.__get__(self.obj, self.cls), args, kw) def is_visitor(cls): """ Decorator to mark any class which contains one or more @when annotations. This is needed if more than one class/scope in the module contains an overloaded method with the same name. If it is missing from a class, the methods on the next class will not be held distinct from the methods on this class. """ global _methods _methods = {} return cls