#=========================================================================
# BehavioralRTLIRGenL1Pass.py
#=========================================================================
# Author : Peitian Pan
# Date   : Oct 20, 2018
"""Provide L1 behavioral RTLIR generation pass."""

import ast
import copy

import pymtl3.dsl as dsl
from pymtl3 import MetadataKey
from pymtl3.datatypes import (
    Bits,
    concat,
    is_bitstruct_class,
    is_bitstruct_inst,
    reduce_and,
    reduce_or,
    reduce_xor,
    sext,
    trunc,
    zext,
)
from pymtl3.passes.rtlir.errors import PyMTLSyntaxError
from pymtl3.passes.rtlir.RTLIRPass import RTLIRPass
from pymtl3.passes.rtlir.rtype.RTLIRType import RTLIRGetter
from pymtl3.passes.rtlir.util.utility import get_ordered_upblks, get_ordered_update_ff

from . import BehavioralRTLIR as bir


class BehavioralRTLIRGenL1Pass( RTLIRPass ):
  """Pass that generates behavioral RTLIR for the update blocks of a component.

  The result is stored on the component under the ``rtlir_upblks`` metadata
  key as a dict that maps each update block function to its RTLIR node.
  """

  # Pass metadata

  #: A dictionary that maps upblk functions to their BIR representation
  #:
  #: Type: ``dict``; output
  rtlir_upblks = MetadataKey()

  def __init__( s, translation_top ):
    """Remember the translation top and make sure it carries an RTLIR getter."""
    c = s.__class__
    s.tr_top = translation_top
    if not translation_top.has_metadata( c.rtlir_getter ):
      translation_top.set_metadata( c.rtlir_getter, RTLIRGetter(cache=True) )

  def __call__( s, m ):
    """Generate RTLIR for all upblks of m."""
    c = s.__class__

    # Reuse the dict already attached to m (if any) so that results from
    # earlier invocations are kept and extended rather than replaced.
    if m.has_metadata( c.rtlir_upblks ):
      rtlir_upblks = m.get_metadata( c.rtlir_upblks )
    else:
      rtlir_upblks = {}
      m.set_metadata( c.rtlir_upblks, rtlir_upblks )

    visitor = s.get_rtlir_generator_class()( m )
    upblks = {
      bir.CombUpblk : get_ordered_upblks(m),
      bir.SeqUpblk  : get_ordered_update_ff(m),
    }

    # Sort the upblks by their name
    upblks[bir.CombUpblk].sort( key = lambda x: x.__name__ )
    upblks[bir.SeqUpblk ].sort( key = lambda x: x.__name__ )

    for upblk_type in ( bir.CombUpblk, bir.SeqUpblk ):
      for blk in upblks[ upblk_type ]:
        # The visitor needs to know which kind of upblk node to construct
        visitor._upblk_type = upblk_type
        upblk_info = m.get_update_block_info( blk )
        # The last entry of the info tuple is handed to the visitor as the
        # AST of the block; entries 0..3 are copied onto the RTLIR node.
        upblk = visitor.enter( blk, upblk_info[-1] )
        upblk.is_lambda = upblk_info[0]
        upblk.src       = upblk_info[1]
        upblk.lino      = upblk_info[2]
        upblk.filename  = upblk_info[3]
        rtlir_upblks[ blk ] = upblk

  def get_rtlir_generator_class( s ):
    """Return the AST visitor class used to generate RTLIR at this level."""
    return BehavioralRTLIRGeneratorL1


class BehavioralRTLIRGeneratorL1( ast.NodeVisitor ):
  """AST visitor that converts an update block into L1 behavioral RTLIR.

  Unsupported constructs raise either PyMTLSyntaxError or
  NotImplementedError.
  """

  def __init__( s, component ):
    s.component = component

  def enter( s, blk, ast ):
    """Entry point of RTLIR generation.

    `blk` is the update block function and `ast` is its AST. Note that the
    `ast` parameter shadows the `ast` module inside this method only.
    """
    s.blk = blk

    # s.globals contains a dict of the global namespace of the module where
    # blk was defined
    s.globals = blk.__globals__

    # s.closure contains the free variables defined in an enclosing scope.
    # Basically this is the model instance s.
    s.closure = {}

    for i, var in enumerate( blk.__code__.co_freevars ):
      try:
        s.closure[ var ] = blk.__closure__[ i ].cell_contents
      except ValueError:
        # Reading the contents of an empty cell raises ValueError; such
        # free variables are simply left out of the closure dict.
        pass

    s.const_extractor = ConstantExtractor( s.blk, s.globals, s.closure )

    ret = s.visit( ast )
    ret.component = s.component
    return ret

  def handle_constant( s, node, obj ):
    """Return the RTLIR node of constant `obj`, or None if not int/Bits."""
    if isinstance( obj, int ):
      return bir.Number( obj )
    elif isinstance( obj, Bits ):
      return bir.SizeCast( obj.nbits, bir.Number( obj.uint() ) )
    else:
      return None

  def get_call_obj( s, node ):
    """Return the live Python object being called by Call node `node`.

    Raise PyMTLSyntaxError if the call uses star, double-star, or keyword
    arguments, or if the callee is not a recognized constant object.
    """
    if hasattr(node, "starargs") and node.starargs:
      raise PyMTLSyntaxError( s.blk, node, 'star argument is not supported!')
    if hasattr(node, "kwargs") and node.kwargs:
      raise PyMTLSyntaxError( s.blk, node,
        'double-star argument is not supported!')
    if node.keywords:
      raise PyMTLSyntaxError( s.blk, node, 'keyword argument is not supported!')

    obj = s.const_extractor.enter( node.func )
    if obj is not None:
      return obj
    else:
      raise PyMTLSyntaxError( s.blk, node,
        f'{node.func} function is not found!' )

  def visit_Module( s, node ):
    """Return the RTLIR of the single function definition in the module."""
    if len( node.body ) != 1 or \
       not isinstance( node.body[0], ast.FunctionDef ):
      raise PyMTLSyntaxError( s.blk, node,
        'Update blocks should have exactly one FuncDef!' )

    ret = s.visit( node.body[0] )
    ret.ast = node
    return ret

  def visit_FunctionDef( s, node ):
    """Return the behavioral RTLIR of function node.

    We do not need to check the decorator list -- the fact that we are
    visiting this node ensures this node was added to the upblk
    dictionary through update() (or other PyMTL decorators) earlier!
    """
    # Check the arguments of the function
    if node.args.args or node.args.vararg or node.args.kwarg:
      raise PyMTLSyntaxError( s.blk, node,
        'Update blocks should not have arguments!' )

    # Save the name of the upblk
    s._upblk_name = node.name

    # Construct the node using the type of upblk
    ret = s._upblk_type( node.name, [] )
    for stmt in node.body:
      ret.body.append( s.visit( stmt ) )
    ret.ast = node
    return ret

  def visit_Assign( s, node ):
    """Return the behavioral RTLIR of an assignment statement."""
    if len( node.targets ) < 1:
      raise PyMTLSyntaxError( s.blk, node,
        'At least one assignment target should be provided!' )

    value = s.visit( node.value )
    targets = [ s.visit( target ) for target in node.targets ]
    ret = bir.Assign( targets, value, False ) # Need a handle to bir node

    # Determine if this is a blocking/non-blocking assignment
    ret.blocking = s.get_blocking(node, ret)
    ret.ast = node
    return ret

  def get_blocking( s, node, bir_node ):
    """Return True if the assignment is blocking (i.e. in a comb upblk)."""
    return s._upblk_type is bir.CombUpblk

  def visit_AugAssign( s, node ):
    """Return the behavioral RTLIR of a non-blocking assignment

    If the given AugAssign is not @= or <<=, throw PyMTLSyntaxError
    """
    if isinstance( node.op, (ast.LShift, ast.MatMult) ):
      value = s.visit( node.value )
      targets = [ s.visit( node.target ) ]
      # <<= is a non-blocking assignment; @= is a blocking one
      blocking = not isinstance( node.op, ast.LShift )
      ret = bir.Assign( targets, value, blocking )
      ret.ast = node
      return ret
    raise PyMTLSyntaxError( s.blk, node,
        'invalid operation: augmented assignment is not @= or <<= assignment!' )

  def _get_cast_nbits( s, node, func_name ):
    """Return the target bitwidth given as 2nd argument of zext/sext/trunc.

    The argument must statically evaluate to an int or a Bits subclass;
    otherwise PyMTLSyntaxError is raised. `func_name` is only used in the
    error message.
    """
    nbits = s.const_extractor.enter( node.args[1] )
    if isinstance( nbits, type ) and issubclass( nbits, Bits ):
      nbits = nbits.nbits
    if not isinstance( nbits, int ):
      raise PyMTLSyntaxError( s.blk, node,
        f'the 2nd argument of {func_name} {nbits} is not a constant int or BitsN type!' )
    return nbits

  def visit_Call( s, node ):
    """Return the behavioral RTLIR of method calls.

    Some data types are interpreted as function calls in the Python AST.
    Example: Bits4(2)
    These are converted to different RTLIR nodes in different contexts.
    """
    num_args = len( node.args )
    obj = s.get_call_obj( node )

    # copy()/deepcopy() are transparent: generate RTLIR of the argument
    if ( obj == copy.copy ) or ( obj == copy.deepcopy ):
      if num_args != 1:
        raise PyMTLSyntaxError( s.blk, node,
          f'copy method {obj} takes exactly 1 argument!')
      ret = s.visit( node.args[0] )
      ret.ast = node
      return ret

    # Now that we have the live Python object, there are a few cases that
    # we need to treat separately:
    # 1. Instantiation: Bits16( 10 ) where obj is an instance of Bits
    # Bits16( 1+2 ), Bits16( s.STATE_A )?
    # 2. concat()
    # 3. zext(), sext()
    # TODO: support the following
    # 4. reduce_and(), reduce_or(), reduce_xor()
    # 5. Real function call: not supported yet

    # Deal with Bits type cast
    if isinstance(obj, type) and issubclass( obj, Bits ):
      nbits = obj.nbits
      if num_args > 1:
        raise PyMTLSyntaxError( s.blk, node,
          'exactly one or zero argument should be given to Bits!' )
      if num_args == 0:
        ret = bir.SizeCast( nbits, bir.Number( 0 ) )
      else:
        ret = bir.SizeCast( nbits, s.visit( node.args[0] ) )

    # concat method
    elif obj is concat:
      if num_args < 1:
        raise PyMTLSyntaxError( s.blk, node,
          'at least one argument should be given to concat!' )
      values = [s.visit(c) for c in node.args]
      ret = bir.Concat( values )

    # zext method
    elif obj is zext:
      if num_args != 2:
        raise PyMTLSyntaxError( s.blk, node,
          'exactly two arguments should be given to zext!' )
      nbits = s._get_cast_nbits( node, 'zext' )
      ret = bir.ZeroExt( nbits, s.visit( node.args[0] ) )

    # sext method
    elif obj is sext:
      if num_args != 2:
        raise PyMTLSyntaxError( s.blk, node,
          'exactly two arguments should be given to sext!' )
      nbits = s._get_cast_nbits( node, 'sext' )
      ret = bir.SignExt( nbits, s.visit( node.args[0] ) )

    # trunc method
    elif obj is trunc:
      if num_args != 2:
        raise PyMTLSyntaxError( s.blk, node,
          'exactly two arguments should be given to trunc!' )
      nbits = s._get_cast_nbits( node, 'trunc' )
      ret = bir.Truncate( nbits, s.visit( node.args[0] ) )

    # reduce methods
    elif obj is reduce_and or obj is reduce_or or obj is reduce_xor:
      if obj is reduce_and:
        op = bir.BitAnd()
      elif obj is reduce_or:
        op = bir.BitOr()
      else:
        op = bir.BitXor()
      if num_args != 1:
        raise PyMTLSyntaxError( s.blk, node,
          f'exactly one argument should be given to reduce {op} methods!' )
      ret = bir.Reduce( op, s.visit( node.args[0] ) )

    else:
      # Only Bits class instantiation is supported at L1
      # obj may be a constant value (e.g. an int) that has no __name__, so
      # fall back to its repr to keep raising the intended syntax error.
      obj_name = getattr( obj, '__name__', repr( obj ) )
      raise PyMTLSyntaxError( s.blk, node,
        f'Unrecognized method call {obj_name}!' )

    ret.ast = node
    return ret

  def visit_Attribute( s, node ):
    """Return a constant node if statically known, else an Attribute node."""
    obj = s.const_extractor.enter( node )
    ret = s.handle_constant( node, obj )
    if ret:
      return ret

    ret = bir.Attribute( s.visit( node.value ), node.attr )
    ret.ast = node
    return ret

  def visit_Subscript( s, node ):
    """Return the behavioral RTLIR of a subscript (index or slice)."""
    obj = s.const_extractor.enter( node )
    ret = s.handle_constant( node, obj )
    if ret:
      return ret

    value = s.visit( node.value )
    if isinstance( node.slice, ast.Slice ):
      if node.slice.step is not None:
        raise PyMTLSyntaxError( s.blk, node,
          'Slice with steps is not supported!' )
      lower, upper = s.visit( node.slice )
      ret = bir.Slice( value, lower, upper )
      ret.ast = node
      return ret

    # signal[ index ]
    # index might be a slice object!
    # NOTE(review): Python 3.9+ no longer wraps indices in ast.Index, so
    # plain indexing falls through to the error below there -- confirm the
    # supported interpreter versions.
    if isinstance( node.slice, ast.Index ):
      idx = s.visit( node.slice )
      # If we have a static slice object then use it
      if isinstance( idx, bir.FreeVar ) and isinstance( idx.obj, slice ):
        slice_obj = idx.obj
        if slice_obj.step is not None:
          raise PyMTLSyntaxError( s.blk, node,
            'Slice with steps is not supported!' )
        assert isinstance( slice_obj.start, int ) and \
               isinstance( slice_obj.stop, int ), \
            f"start and stop of slice object {slice_obj} must be integers!"
        ret = bir.Slice( value,
              bir.Number(slice_obj.start), bir.Number(slice_obj.stop) )
      # Else this is a real index
      else:
        ret = bir.Index( value, idx )
      ret.ast = node
      return ret

    # An AST node cannot be concatenated to a str; dump it instead
    raise PyMTLSyntaxError( s.blk, node,
      f'Illegal subscript {ast.dump( node )} encountered!' )

  def visit_Slice( s, node ):
    """Return the (lower, upper) RTLIR pair of a slice."""
    return ( s.visit( node.lower ), s.visit( node.upper ) )

  def visit_Index( s, node ):
    return s.visit( node.value )

  def visit_Name( s, node ):
    """Return the RTLIR of a name found in the closure or global namespace."""
    if node.id in s.closure:
      # free var from closure
      obj = s.closure[ node.id ]
      if isinstance( obj, dsl.Component ):
        # Component freevars are an L1 thing.
        if obj is not s.component:
          raise PyMTLSyntaxError( s.blk, node,
            f'Component {obj} is not a sub-component of {s.component}!' )
        ret = bir.Base( obj )
      else:
        # A closure variable could be a loop index. We need to
        # generate per-function closure variable instead of assuming
        # they will have the same value.
        ret = bir.FreeVar( f"{node.id}_at_{s.blk.__name__}", obj )
      ret.ast = node
      return ret

    elif node.id in s.globals:
      # free var from the global name space
      # For now we can still safely assume all upblks will see the same
      # value for a free var from the global space?
      ret = bir.FreeVar( node.id, s.globals[ node.id ] )
      ret.ast = node
      return ret

    raise PyMTLSyntaxError( s.blk, node,
      f'Temporary variable {node.id} is not supported at L1!' )

  def visit_Num( s, node ):
    ret = bir.Number( node.n )
    ret.ast = node
    return ret

  def visit_If( s, node ): raise NotImplementedError()

  def visit_For( s, node ): raise NotImplementedError()

  def visit_BoolOp( s, node ): raise NotImplementedError()

  def visit_BinOp( s, node ): raise NotImplementedError()

  def visit_UnaryOp( s, node ): raise NotImplementedError()

  def visit_IfExp( s, node ): raise NotImplementedError()

  def visit_Compare( s, node ): raise NotImplementedError()

  # $display
  def visit_Print( s, node ): raise NotImplementedError()

  # function
  def visit_Return( s, node ): raise NotImplementedError()

  # SV assertion
  def visit_Assert( s, node ): raise NotImplementedError()

  def visit_Expr( s, node ):
    """Return the behavioral RTLIR of an expression.

    ast.Expr might be useful when a statement is only a call to a task or
    a non-returning function.
    """
    raise PyMTLSyntaxError( s.blk, node,
      'Stand-alone expression is not supported yet!' )

  def visit_Lambda( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: lambda function' )

  def visit_Dict( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid type: dict' )

  def visit_Set( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid type: set' )

  def visit_List( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid type: list' )

  def visit_Tuple( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid type: tuple' )

  def visit_ListComp( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: list comprehension' )

  def visit_SetComp( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: set comprehension' )

  def visit_DictComp( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: dict comprehension' )

  def visit_GeneratorExp( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: generator expression' )

  def visit_Yield( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: yield' )

  def visit_Repr( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: repr' )

  def visit_Str( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: str' )

  def visit_ClassDef( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: classdef' )

  def visit_Delete( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: delete' )

  def visit_With( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: with' )

  def visit_Raise( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: raise' )

  def visit_TryExcept( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: try-except' )

  def visit_TryFinally( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: try-finally' )

  def visit_Import( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: import' )

  def visit_ImportFrom( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: import-from' )

  def visit_Exec( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: exec' )

  def visit_Global( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: global' )

  def visit_Pass( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: pass' )

  def visit_Break( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: break' )

  def visit_Continue( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: continue' )

  def visit_While( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: while' )

  def visit_ExtSlice( s, node ):
    raise PyMTLSyntaxError( s.blk, node, 'invalid operation: extslice' )


class ConstantExtractor( ast.NodeVisitor ):
  """Statically evaluate an AST expression to a live Python object.

  Names are looked up in the closure namespace first and then in the
  global namespace. Node types without a visit method evaluate to None.
  """

  def __init__( s, blk, global_ns, closure_ns ):
    s.blk = blk
    s.globals = global_ns
    # Memoizes results of Attribute/Subscript/Index nodes, keyed by node
    s.cache = {}
    s.closure = closure_ns
    s.pymtl_functions = {
      concat, sext, zext, trunc,
      reduce_or, reduce_and, reduce_xor,
      copy.copy, copy.deepcopy
    }

  def generic_visit( s, node ):
    return None

  def enter( s, node ):
    """Return the recognized constant object of `node`, or None."""
    ret = s.visit( node )
    # Constant objects that are recognized
    # 1. int, BitsN( X )
    # 2. BitsN
    # 3. BitStruct, BitStruct()
    # 4. Functions, including concat, zext, sext, etc.
    is_value = isinstance(ret, (int, Bits)) or is_bitstruct_inst(ret)
    is_type = isinstance(ret, type) and (issubclass(ret, Bits) or is_bitstruct_class(ret))
    try:
      is_function = ret in s.pymtl_functions
    except Exception:
      # Set membership fails for unhashable objects; those are not one of
      # the recognized functions.
      is_function = False

    if is_value or is_type or is_function:
      return ret
    else:
      return None

  def visit_Attribute( s, node ):
    if node in s.cache:
      return s.cache[node]
    value = s.visit( node.value )
    try:
      ret = getattr( value, node.attr )
    except AttributeError:
      ret = None
    s.cache[node] = ret
    return ret

  def visit_Subscript( s, node ):
    if node in s.cache:
      return s.cache[node]
    ret = None
    if isinstance( node.slice, ast.Index ):
      value = s.visit( node.value )
      idx = s.visit( node.slice )
      if value is not None and idx is not None:
        try:
          ret = value[idx]
        except Exception:
          # Best effort: anything that cannot be subscripted statically is
          # treated as a non-constant.
          ret = None
    s.cache[node] = ret
    return ret

  def visit_Index( s, node ):
    if node in s.cache:
      return s.cache[node]
    ret = s.visit( node.value )
    s.cache[node] = ret
    return ret

  def visit_Name( s, node ):
    name = node.id
    if name in s.closure:
      # free var from closure
      obj = s.closure[ name ]
    elif name in s.globals:
      # free var from the global name space
      obj = s.globals[ name ]
    else:
      obj = None
    return obj

  def visit_Num( s, node ):
    return node.n