"""Operator Loader.""" from collections import namedtuple import os import sys import re import time import ctypes import json import warnings import portalocker from ..edict import edict from ..func import CFuncDef, bind, get_func_idcode, get_idcode_hash from ..build import config, source_to_so_ctx, build_context, file_is_changed, ENV_PATH from ..utils import get_git_hash, makedirs from ..dtype import DType, TemplateType from ..version import OP_LOAD_MODULE_BUILD_VERSION from ..glue.backend import get_glue_modules from .gen_code import get_gen_rel_code gen_code = get_gen_rel_code(os.path.dirname(__file__)) if sys.version_info[0] >= 3: import importlib.util def load_module(name, pathname): """Load Module. Paramters --------- name: str the name of module. pathname: the name of path. Returns ------- Module """ spec = importlib.util.spec_from_file_location(name, pathname) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module else: import imp def load_module(name, pathname): """Load Module. Paramters --------- name: str the name of module. pathname: the name of path. Returns ------- Module """ module = imp.load_source(name, pathname) return module def _get_func_head_reg(name): """Get a pattern object for CFunction Head. Paramters --------- name: str Function name. Returns ------- A pattern object """ return re.compile(r'^\s*{}\s*(.*)'.format(name)) MOBULA_KERNEL_REG = _get_func_head_reg('MOBULA_(KERNEL|FUNC)') FUNC_REG = re.compile( r'^\s*(.*?)\s*\((.*?)\)(?:.*?)*') CPP_TEMPLATE_REG = re.compile(r'^\s*template\s*\<(.*?)\>\s*') def _get_template_decl(code): match = CPP_TEMPLATE_REG.search(code) if match is None: return None blocks = match.groups()[0].split(',') templates = [] for block in blocks: block_sp = block.split() dtype, dname = block_sp if dtype.strip() == 'typename': templates.append(dname.strip()) return templates def parse_parameter_decl(decl): """Parse the code of parameter declaration Parameters ---------- decl : str The C++ code of parameter declaration Returns ------- Tuple (DType Instance, variable name) """ num_star = decl.count('*') assert num_star <= 1,\ Exception('Only support pass-by-value or pass-by-1-level-pointer, \ Error declaration: {}'.format(decl)) is_pointer = num_star > 0 if is_pointer: decl = decl.replace('*', '') decl = decl.strip() if decl.startswith('const '): is_const = True decl = decl[len('const '):] else: is_const = False decl_sp = decl.split(' ') # type_name and variable_name in C++ code type_name, var_name = decl_sp # void* func(...) if type_name == 'void': assert is_pointer return DType(ctypes.c_void_p, is_const=is_const), var_name # ctype func(...) ctype_name = 'c_{}'.format(type_name) if hasattr(ctypes, ctype_name): ctype = getattr(ctypes, ctype_name) if is_pointer: ctype = ctypes.POINTER(ctype) return DType(ctype, is_const=is_const), var_name # template type return TemplateType(tname=type_name, is_pointer=is_pointer, is_const=is_const), var_name def parse_parameters_list(plist): """Parse the code of parameter declaration list Parameters ---------- plist : str The code of parameter declaration list Returns ------- rtn_type : The type of return value func_name : str function name pars_list: list [(DType|TemplateType, variable name), ...] """ match = FUNC_REG.search(plist) head, plist = match.groups() head_split = re.split(r'\s+', head) plist_split = re.split(r'\s*,\s*', plist) func_name = head_split[-1] rtn_type = head_split[-2] if len(head_split) == 3 else None pars_list = list(map(parse_parameter_decl, plist_split)) return rtn_type, func_name, pars_list # runtime FuncInfo = namedtuple('FuncInfo', ['func', 'cpp_info']) CTX_FUNC_MAP = dict() # CTX_FUNC_MAP[ctx][cpp_fname] -> FuncInfo class CPPInfo: """The class of the C++ file's information. Parameters ---------- cpp_fname: str the filename of C++ file. """ def __init__(self, cpp_fname): self.cpp_fname = cpp_fname self.function_args = dict() self.dll = None def load_dll(self, dll_fname): """Load Dynamic-Link Library(*.so or *.dll). Parameters ---------- dll_fname: The name of Dynamic-Link Library. """ # keep reference self.dll = ctypes.CDLL(dll_fname) def _build_lib(cpp_fname, code_buffer, ctx, target_name): cpp_path, cpp_basename = os.path.split(cpp_fname) build_path = os.path.join(cpp_path, 'build') create_time = time.strftime('%a %Y-%m-%d %H:%M:%S (%z)', time.localtime()) git_hash = get_git_hash() extra_code = gen_code('./templates/header.cpp')( cpp_fname=cpp_fname, git_hash=git_hash, create_time=create_time, inc_fname=os.path.normpath(os.path.join('../..', cpp_basename)), code=code_buffer) build_path_ctx = os.path.join(build_path, ctx) makedirs(build_path_ctx, exist_ok=True) # build so cpp_wrapper_fname = os.path.join(build_path_ctx, os.path.splitext(cpp_basename)[0] + '_wrapper.cpp') with open(cpp_wrapper_fname, 'w') as fout: fout.write(extra_code) # build lib srcs = [cpp_wrapper_fname] source_to_so_ctx(build_path, srcs, target_name, ctx) def _dtype_to_tvm_value_type(dtype): if dtype.is_pointer: return 'v_handle' if 'int' in dtype.cname: return 'v_int64' return 'v_float64' def _get_args_inst_mx(i, t): s = 'args.values[%d].%s' % (i, _dtype_to_tvm_value_type(t)) if t.is_pointer: return ''' static_cast<{dtype}>( static_cast<DLTensor*>({tv})->data)'''.format(dtype=t.cname, tv=s) else: s = '\n ' + s return s def _generate_kernel_code(func_idcode_hash, arg_types, arg_names, func_name): args_def = ', '.join(['{ctype} {name}'.format( ctype=dtype.cname, name=name ) for dtype, name in zip(arg_types, arg_names)]) args_inst = ', '.join(arg_names) kernel_code = gen_code('./templates/kernel_code.cpp')( func_idcode_hash=func_idcode_hash, args_def=args_def, func_name=func_name, args_inst=args_inst) kernel_code += '\n' args_def_async_mx = ', '.join(['{ctype} {name}'.format( ctype='NDArrayHandle' if dtype.is_pointer else dtype.cname, name=name ) for dtype, name in zip(arg_types, arg_names)]) using_async_mx = all( map(lambda dtype: 'void' not in dtype.cname, arg_types)) if using_async_mx: args_inst_mx = [_get_args_inst_mx(i, t) for i, t in enumerate(arg_types)] const_loc = [] for i, dtype in enumerate(arg_types): if dtype.is_const and dtype.is_pointer: const_loc.append(i) num_const = len(const_loc) const_loc_code = 'nullptr' if num_const == 0 else 'std::array<int, %d>({%s}).data()' % ( num_const, ','.join([str(u) for u in const_loc])) async_mx_code = gen_code('./templates/async_mx_code.cpp')( func_idcode_hash=func_idcode_hash, func_name=func_name, args_inst=args_inst, args_inst_mx=','.join(args_inst_mx), num_const=num_const, const_loc_code=const_loc_code, args_def_async_mx=args_def_async_mx, ) async_mx_code += '\n' kernel_code += async_mx_code return kernel_code def _generate_func_code(func_idcode_hash, rtn_type, arg_types, arg_names, func_name): if rtn_type is None: rtn_type = 'void' args_def = ', '.join(['{ctype} {name}'.format( ctype=dtype.cname, name=name ) for dtype, name in zip(arg_types, arg_names)]) args_inst = ', '.join(arg_names) code = ''' MOBULA_DLL %s %s(%s) { ''' % (rtn_type, func_idcode_hash, args_def) if rtn_type != 'void': code += ' return ' code += '%s(%s);\n}\n' % (func_name, args_inst) return code def _generate_ordinary_code(cpp_info): code_buffer = '' # generate ordinary functions code for func_name, ord_cfunc in cpp_info.function_args.items(): if ord_cfunc.template_list: continue func_idcode = get_func_idcode(func_name, ord_cfunc.arg_types) func_idcode_hash = get_idcode_hash(func_idcode) func_kind = ord_cfunc.func_kind if func_kind == CFuncDef.KERNEL: code_buffer += _generate_kernel_code( func_idcode_hash, ord_cfunc.arg_types, ord_cfunc.arg_names, '{}_kernel'.format(func_name)) code_buffer += '\n' return code_buffer def _get_ordinary_functions(cpp_info): res = list() for func_name, ord_cfunc in cpp_info.function_args.items(): if ord_cfunc.template_list: continue func_idcode = get_func_idcode(func_name, ord_cfunc.arg_types) res.append(func_idcode) return res def _update_template_inst_map(idcode, template_functions, cfunc, arg_types): # template function func_name = cfunc.func_name func_idcode_hash = get_idcode_hash(idcode) # Check Template Type Mapping template_mapping = dict() for rtype, dtype in zip(arg_types, cfunc.arg_types): if not isinstance(dtype, TemplateType): continue tname = dtype.tname rtype = str(rtype).replace( 'const', '').replace('*', '').strip() if tname in template_mapping: assert template_mapping[tname] == rtype,\ Exception('Excepted template type {} instead of {}'. format(template_mapping[tname], rtype)) else: template_mapping[tname] = rtype assert len(template_mapping) == len(cfunc.template_list),\ Exception('Template List: {}, mapping: {}'. format(cfunc.template_list, template_mapping)) template_inst = [template_mapping[tname] for tname in cfunc.template_list] template_post = '<%s>' % (', '.join(template_inst)) rtn_type = cfunc.rtn_type if rtn_type in template_mapping: rtn_type = template_mapping[rtn_type] func_kind = cfunc.func_kind if func_kind == CFuncDef.KERNEL: code = _generate_kernel_code(func_idcode_hash, arg_types, cfunc.arg_names, '({}_kernel{})'.format( func_name, template_post)) else: code = _generate_func_code( func_idcode_hash, rtn_type, arg_types, cfunc.arg_names, func_name + template_post) template_functions[idcode] = code def _add_function(func_map, func_idcode, cpp_info, dll_fname): func_idcode_hash = get_idcode_hash(func_idcode) func = getattr(cpp_info.dll, func_idcode_hash, None) assert func is not None,\ Exception('No function `{}` in DLL {}'.format( func_idcode, dll_fname)) old_func = func_map.get(func_idcode, None) if old_func is not None: if old_func.cpp_info.cpp_fname != cpp_info.cpp_fname: warnings.warn('The function `{}` in `{}` will be overridden by that in `{}`'.format( func_idcode, old_func.cpp_info.cpp_fname, cpp_info.cpp_fname)) func_map[func_idcode] = FuncInfo(func=func, cpp_info=cpp_info) class OpLoader: '''Import Operator Loader. It's actual to load the operator. Parameters ---------- cfunc: CFuncDef The definition of function to call. arg_types: list of {DType|TemplateType} Argument declaration list. ctx: str Building context. cpp_info: CPPInfo Related to cfunc. Returns ------- CTX_FUNC_MAP[ctx][fname][idcode] : FuncInfo ''' def __init__(self, cfunc, arg_types, ctx, cpp_info): idcode = get_func_idcode(cfunc.func_name, arg_types) if ctx not in CTX_FUNC_MAP: CTX_FUNC_MAP[ctx] = dict() cpp_fname = cpp_info.cpp_fname if cpp_fname not in CTX_FUNC_MAP[ctx]: CTX_FUNC_MAP[ctx][cpp_fname] = dict() # func_map: dict mapping idcode to CFunction func_map = CTX_FUNC_MAP[ctx][cpp_fname] if idcode not in func_map: ''' *load function* when one of the following conditions is True: 1. idcode is not loaded 2. loading the function with same function name but different cpp filename ''' cpp_path, cpp_basename = os.path.split(cpp_fname) build_path = os.path.join(cpp_path, 'build') use_template = bool(cfunc.template_list) makedirs(build_path, exist_ok=True) build_info_fname = os.path.join( build_path, os.path.splitext(cpp_basename)[0] + '.json') build_info_fs = open(build_info_fname, 'a+') portalocker.lock(build_info_fs, portalocker.LOCK_EX) build_info_fs.seek(0) js_data = build_info_fs.read() if js_data: map_data = json.loads(js_data) else: map_data = dict(version=OP_LOAD_MODULE_BUILD_VERSION) del js_data # try to load the instance of template function # map_data is a dict which records build information if map_data.get('version') > OP_LOAD_MODULE_BUILD_VERSION: portalocker.unlock(build_info_fs) raise Exception( """Unsupported higher version %s of wrapper file (Current MobulaOP ver: %s) :-(. Please update MobulaOP.""" % (map_data.get('version'), OP_LOAD_MODULE_BUILD_VERSION)) build_id = map_data.get('build_id', 0) is_old_version = map_data.get( 'version') < OP_LOAD_MODULE_BUILD_VERSION # load the information of template functions ORDINARY_FUNCTION_NAME = 'ordinary_functions' TEMPLATE_FUNCTION_NAME = 'template_functions' if is_old_version: ordinary_functions = list() template_functions = dict() else: ordinary_functions = map_data.get( ORDINARY_FUNCTION_NAME, list()) template_functions = map_data.get( TEMPLATE_FUNCTION_NAME, dict()) so_prefix = os.path.join( cpp_path, 'build', os.path.splitext(cpp_basename)[0]) # The filename of build target dll_fname_format = '{prefix}_{ctx}'.format( prefix=so_prefix, ctx=ctx) + '_{build_id}.so' dll_fname = dll_fname_format.format(build_id=build_id) file_changed = file_is_changed(cpp_fname) dll_existed = os.path.exists(dll_fname) func_existed = idcode in template_functions or idcode in ordinary_functions if file_changed or not dll_existed or not func_existed or is_old_version: # Rebuild DLL file try: # try to remove old DLL file os.remove(dll_fname) except: pass if file_changed: # clear template_functions since some functions may have been deleted or renamed after codefile is changed. template_functions.clear() if file_changed or not func_existed: ''' we increase `build_id` by 1 when one of the following conditions is True: 1. the cpp file has been changed 2. new idcode When the cpp file is not changed, and idcode exists in template_functions, `build_id` will be not changed. ''' build_id += 1 dll_fname = dll_fname_format.format(build_id=build_id) # build code code_buffer = _generate_ordinary_code(cpp_info) ordinary_functions = _get_ordinary_functions(cpp_info) if use_template: if idcode not in template_functions: _update_template_inst_map( idcode, template_functions, cfunc, arg_types) # add template instances code into code_buffer code_buffer += ''.join(template_functions.values()) with build_context(): try: _build_lib(cpp_fname, code_buffer, ctx, dll_fname) except: # if build fail, unlock the build info file portalocker.unlock(build_info_fs) raise # update template_functions map_data = dict(version=OP_LOAD_MODULE_BUILD_VERSION, build_id=build_id) map_data[ORDINARY_FUNCTION_NAME] = ordinary_functions map_data[TEMPLATE_FUNCTION_NAME] = template_functions # clear the old context and write json data build_info_fs.seek(0) build_info_fs.truncate() json.dump(map_data, build_info_fs) build_info_fs.flush() os.fsync(build_info_fs.fileno()) portalocker.unlock(build_info_fs) # load all functions in the dll cpp_info.load_dll(dll_fname) # import all functions # ordinary functions for func_name, ord_cfunc in cpp_info.function_args.items(): if not ord_cfunc.template_list: func_idcode = get_func_idcode( func_name, ord_cfunc.arg_types) _add_function(func_map, func_idcode, cpp_info, dll_fname) # template functions for func_idcode in template_functions.keys(): _add_function(func_map, func_idcode, cpp_info, dll_fname) self.func = func_map[idcode].func self.cpp_info = func_map[idcode].cpp_info self.idcode_hash = get_idcode_hash(idcode) def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) def get_async_func(self, glue_mod): async_name = getattr(glue_mod, 'async_name', None) if async_name is None: return None return glue_mod.get_async_func(self.cpp_info, self.idcode_hash) def _get_functions_from_cpp(cpp_fname): unmatched_brackets = 0 func_def = '' func_kind = '' func_started = False template_list = [] cpp_info = CPPInfo(cpp_fname=cpp_fname) function_args = cpp_info.function_args for line in open(cpp_fname): if not func_started: current_template_list = _get_template_decl(line) if current_template_list is not None: template_list = current_template_list match = MOBULA_KERNEL_REG.search(line) if match is not None: func_def = '' func_kind_str = match.groups()[0] if func_kind_str == 'KERNEL': func_kind = CFuncDef.KERNEL elif func_kind_str == 'FUNC': func_kind = CFuncDef.FUNC else: raise TypeError( 'Unknown kind of function: %s' % func_kind_str) func_started = True # In a declaration of a function if func_started: unmatched_brackets += line.count('(') - line.count(')') func_def += line if unmatched_brackets == 0: func_def = func_def.replace('\n', '').replace('\r', '') func_started = False rtn_type, kernel_name, par_list = parse_parameters_list( func_def) # template name check template_set = set(template_list) assert len(template_set) == len(template_list),\ Exception('Duplicated template name in {}'.format( ', '.join(template_list))) use_template = False for dtype, _ in par_list: if isinstance(dtype, TemplateType): assert dtype.tname in template_set,\ Exception( "template name '{}' is not defined".format(dtype.tname)) use_template = True if not use_template: template_list = [] if func_kind == CFuncDef.KERNEL: assert kernel_name.endswith('_kernel'),\ Exception('the postfix of a MOBULA_KERNEL name must be `_kernel`, \ e.g. addition_forward_kernel') func_name = kernel_name[:-len('_kernel')] elif func_kind == CFuncDef.FUNC: func_name = kernel_name else: raise Exception( 'Unknown function kind: {}'.format(func_kind)) # Arguments funcdef_args = edict(func_name=func_name, func_kind=func_kind, arg_names=[t[1] for t in par_list], arg_types=[t[0] for t in par_list], rtn_type=rtn_type, template_list=template_list, loader=OpLoader, loader_kwargs=dict( cpp_info=cpp_info, ) ) template_list = [] function_args[func_name] = funcdef_args assert unmatched_brackets == 0,\ Exception('# unmatched brackets: {}'.format(unmatched_brackets)) # Load dynamic file functions = dict( (name, CFuncDef(**kwargs)) for name, kwargs in function_args.items()) # Load dynamic function for MXNet return functions def load(module_name, path=''): """Load Operator Module Parameters ---------- module_name: str The name of Operator Module path: str The path of Operator Module [default = current path] """ op_name = os.path.basename(module_name) if not path: # Find Operator Module in custom directory first custom_path = os.path.join(os.path.dirname(__file__), '../../opzoo') if os.path.exists(os.path.join(custom_path, op_name)): path = custom_path path = os.path.join(path, module_name) found = False cpp_fname = os.path.join(path, op_name + '.cpp') if os.path.exists(cpp_fname): found = True # Get functions functions = _get_functions_from_cpp(cpp_fname) bind(functions) py_fname = os.path.join(path, op_name + '.py') if not os.path.exists(py_fname): py_fname = os.path.join(path, '__init__.py') if os.path.exists(py_fname): found = True # Create Operator load_module(op_name, py_fname) assert found,\ IOError("{op_name}.cpp or {op_name}.py or __init__.py not found\ in the path {path}".format(op_name=op_name, path=path))