from restrain_jit.becython.cy_jit_hotspot_comp import (
    Strategy, HitN, JITRecompilationDecision, extension_type_pxd_paths,
    is_jit_able_type)
from restrain_jit.becython.cy_method_codegen import CodeEmit, UndefGlobal
from restrain_jit.becython.cy_jit_common import *
from restrain_jit.becython.cython_vm import CyVM
from restrain_jit.becython.cy_jit_ext_template import (
    mk_module_code, mk_call_record_t, mk_hard_coded_method_getter_module)
from restrain_jit.becython.cy_loader import (
    compile_module, JIT_FUNCTION_DIR, JIT_TYPE_DIR, ximport_module)
from restrain_jit.jit_info import PyFuncInfo, PyCodeInfo
from restrain_jit.utils import CodeOut
from restrain_jit.becython import cy_jit_common as cmc
from types import FunctionType, ModuleType
import typing as t


class JITFunctionHoldPlace:
    """Per-function bookkeeping record owned by a :class:`JITSystem`.

    Holds the compiled dispatcher module of one jitted function, the address
    of its base (untyped) method, the table of type-specialised method
    addresses, and the partial code kept around for later specialisation.
    """
    memoize_partial_code: t.Optional[CodeOut]
    method_arguments: t.Optional[t.Tuple[str, ...]]
    function_module: ModuleType
    base_method_module: ModuleType
    methods: t.Dict[call_record_t, fptr_addr]

    def __init__(self, jit_sys: 'JITSystem', code_info: PyCodeInfo):
        """Create an empty record for the function described by `code_info`.

        :param jit_sys: the owning JIT system; used by :meth:`add_method`.
        :param code_info: supplies the function's name and argument count.
        """
        self.sys = jit_sys
        self.globals = {}
        self.methods: t.Dict[call_record_t, fptr_addr] = {}
        # do not access code_info.glob_deps,
        # due to some global var like `type`, `len` are removed from here
        self.glob_deps = ()
        self.name_that_makes_sense = code_info.name
        self.argc = code_info.argcount
        self.base_method_code = ""
        self.base_method_addr = None
        self._function_module = None
        self._fn_ptr_name = None
        self.call_recorder = None
        self.memoize_partial_code = None
        self.method_arguments = None

    @property
    def function_module(self):
        """The compiled dispatcher module; asserts that it has been set."""
        assert self._function_module
        return self._function_module

    @function_module.setter
    def function_module(self, v):
        self._function_module = v

    @property
    def fn_ptr_name(self):
        """Name of the generated function; asserts that it has been set."""
        assert self._fn_ptr_name
        return self._fn_ptr_name

    @fn_ptr_name.setter
    def fn_ptr_name(self, v):
        self._fn_ptr_name = v

    @property
    def unique_index(self):
        """Identity-based index of this record (``id(self)``)."""
        return id(self)

    def counter_name(self) -> str:
        """Not implemented."""
        raise NotImplementedError

    def globals_from(self, globs: dict):
        """Remember `globs` as the globals of the jitted function."""
        self.globals = globs

    def add_methods(self, dispatcher: t.Dict[call_record_t, t.Tuple[type, ...]]):
        """Compile one specialised method per entry of `dispatcher`.

        Does nothing for an empty mapping. Otherwise every entry is passed to
        :meth:`add_method` and the method getter is rebuilt once at the end.

        :param dispatcher: maps a call record to the tuple of argument types
            it stands for.
        """
        if not dispatcher:
            return
        for argtypeids, argtypes in dispatcher.items():
            self.add_method(argtypeids, argtypes)
        self._rebuild_method_getter_and_set()

    def add_method(self, argtypeids: call_record_t, argtypes: t.Tuple[type, ...]):
        """Ask the owning JIT system to compile a single specialised method."""
        self.sys.add_jit_method(self, argtypeids, argtypes)

    def _rebuild_method_getter_and_set(self):
        """Regenerate the method getter from `self.methods` and install it."""
        code = mk_hard_coded_method_getter_module(
            self.methods, self.base_method_addr, self.argc)
        # NOTE(review): the module name is a constant shared by every function
        # and every rebuild; whether that is safe depends on how
        # `compile_module` treats repeated names -- confirm.
        mod = compile_module(JIT_FUNCTION_DIR, "MethodGetter", code)
        self.function_module.f.mut_method_get(mod.method_get_addr)


class JITSystem:
    """Entry point of the Cython-backed JIT.

    Compiles Python functions into Cython extension modules, and compiles
    type-specialised methods for them on request.
    """

    def __init__(self, strategies: t.Optional[t.List[Strategy]] = None):
        """
        :param strategies: recompilation strategies handed to
            `JITRecompilationDecision`; ``[HitN(200)]`` when omitted or empty.
        """
        strategies = strategies or [HitN(200)]
        self.fn_count = 0
        self.jit_hotspot_analysis = JITRecompilationDecision(strategies)
        self.memoize_partial_code = {}
        self.jit_fptr_index = {}
        self.fn_place_index: t.Dict[int, JITFunctionHoldPlace] = {}
        self.store_base_method_log = False

    def jitdata(self, cls: type):
        """Compile a Cython extension type mirroring the annotations of `cls`.

        Each annotated attribute ``a`` becomes a C-level field ``x_a`` with
        ``get_a``/``set_a`` cpdef accessors and a Python property ``a``.
        Annotation types missing from `extension_type_pxd_paths` are stored
        as ``object``; types mapped to ``None`` are referred to by their own
        name; types mapped to a path are cimported from that path under an
        alias.

        :param cls: class whose ``__annotations__`` describe the fields.
        :return: the class named ``cls.__name__`` from the compiled module.
        """
        undef = object()
        anns = {}
        imports = []
        # The loop variable must not be called `t`: that is the module-level
        # alias of `typing`.
        for i, (attr, ann_type) in enumerate(cls.__annotations__.items()):
            path = extension_type_pxd_paths.get(ann_type, undef)
            if path is undef:
                anns[attr] = 'object'
            elif path is None:
                anns[attr] = ann_type.__name__
            else:
                typename = "{}{}".format(cmc.typed_head, i)
                anns[attr] = typename
                imports.append('from {} cimport {} as {}'.format(
                    path, ann_type.__name__, typename))

        # Implementation file (.pyx). Indentation inside the generated source
        # is significant: class members at 4 spaces, method bodies at 8.
        pyx_lines = ["cdef class {}:".format(cls.__name__)]
        for attr, typestr in anns.items():
            pyx_lines.append("    cdef {} x_{}".format(typestr, attr))
        pyx_lines.append("    def __init__(self, {}):".format(', '.join(anns)))
        for attr in anns:
            pyx_lines.append("        self.x_{0} = {0}".format(attr))
        for attr, typestr in anns.items():
            pyx_lines.append("    cpdef {1} get_{0}(self):".format(attr, typestr))
            pyx_lines.append("        return self.x_{0}".format(attr))
            pyx_lines.append("    @property")
            pyx_lines.append("    def {0}(self):".format(attr))
            pyx_lines.append("        return self.x_{0}".format(attr))
            pyx_lines.append(
                "    cpdef void set_{0}(self, {1} {0}):".format(attr, typestr))
            pyx_lines.append("        self.x_{0} = {0}".format(attr))
            pyx_lines.append("    @{}.setter".format(attr))
            pyx_lines.append("    def {0}(self, {1} {0}):".format(attr, typestr))
            pyx_lines.append("        self.x_{0} = {0}".format(attr))
        pyx = '\n'.join(imports + pyx_lines)

        # Declaration file (.pxd): fields and cpdef signatures only.
        pxd_lines = ["cdef class {}:".format(cls.__name__)]
        for attr, typestr in anns.items():
            pxd_lines.append("    cdef {} x_{}".format(typestr, attr))
        for attr, typestr in anns.items():
            pxd_lines.append("    cpdef void set_{0}(self, {1})".format(attr, typestr))
            pxd_lines.append("    cpdef {1} get_{0}(self)".format(attr, typestr))
        pxd = '\n'.join(imports + pxd_lines)

        mod = ximport_module(JIT_TYPE_DIR, cls.__name__, pyx, pxd)
        return getattr(mod, cls.__name__)

    def jit(self, f: FunctionType):
        """Compile `f` and return the callable ``f`` of its generated module.

        The returned callable is also registered in `fn_place_index` under
        its ``id``.
        """
        func_info = self.get_func_info(f)
        code_info = func_info.r_codeinfo
        fn_place = self.allocate_place_for_function(code_info)
        fn_place.globals_from(func_info.r_globals)
        CodeEmit(self, code_info, fn_place)
        f = fn_place.function_module.f
        self.fn_place_index[id(f)] = fn_place
        return f

    @staticmethod
    def get_func_info(f: FunctionType) -> PyFuncInfo:
        """Return the `PyFuncInfo` that `CyVM` computes for `f`."""
        return CyVM.func_info(f)

    def generate_module_for_code(self, code_info: PyCodeInfo):
        """Compile the dispatcher module for `code_info` and return it.

        Modules are named ``Functions_<n>`` where ``n`` is a per-system
        counter.
        """
        code = mk_module_code(code_info)
        unique_module_name = "Functions_{}".format(self.fn_count)
        # Advance the counter so the next function gets a different name;
        # without this every module would be called "Functions_0".
        self.fn_count += 1
        mod = compile_module(JIT_FUNCTION_DIR, unique_module_name, code)
        return mod

    def remember_partial_code(self, fn_place: JITFunctionHoldPlace,
                              code_out: CodeOut):
        """Store `code_out` on `fn_place` and in the system-wide memo."""
        fn_place.memoize_partial_code = code_out
        self.memoize_partial_code[id(fn_place)] = code_out

    def setup(self):
        """Not implemented."""
        # FIXME: create the temporary directory,
        # as the place to hold all compiled extensions and,
        # the source codes required by re-JIT.
        # ONE RUNTIME USES ONE JIT DIRECTORY
        raise NotImplementedError

    def allocate_place_for_function(self,
                                    code_info: PyCodeInfo) -> JITFunctionHoldPlace:
        """Create the hold place for a function and compile its dispatcher.

        The dispatcher's ``f`` is registered in `jit_fptr_index` under the
        ``id`` of the new hold place.
        """
        # use object id and function name and module name to
        # generate unique path as well as meaningful.
        fn_place = JITFunctionHoldPlace(self, code_info)
        fn_place.function_module = self.generate_module_for_code(code_info)
        self.jit_fptr_index[id(fn_place)] = fn_place.function_module.f
        return fn_place

    def generate_base_method(self, function_place: JITFunctionHoldPlace, code):
        """Compile the base (all-``object``) method of a function.

        `code` is extended with a ``method_t`` typedef, a getter that returns
        the base method whatever its arguments are, and module-level
        ``base_method_addr``/``method_getter_addr`` values. The compiled
        module is then initialised with a call recorder, the function's
        globals and the function-pointer index, and its getter is installed
        on the dispatcher module.

        :param function_place: hold place of the function being compiled; its
            `fn_ptr_name` must already be set.
        :param code: Cython source of the base method.
        """
        argc = function_place.argc
        method_argtype_comma_lst = ', '.join('object' for _ in range(argc))
        method_get_argument_comma_lst = ', '.join(
            'int64_t _%d' % i for i in range(argc))

        # TODO: use auto-evolution-able method lookup
        fn_ptr_name = function_place.fn_ptr_name
        # Built line by line because the layout of the generated Cython
        # source (newlines, indentation of the getter body) is significant.
        code = '\n'.join((
            "",
            code,
            "ctypedef object (*method_t)({})".format(method_argtype_comma_lst),
            "",
            "cdef method_t this_method_getter({}):".format(
                method_get_argument_comma_lst),
            "    return {}".format(fn_ptr_name),
            "",
            "base_method_addr = reinterpret_cast[int64_t](<void*>{})".format(
                fn_ptr_name),
            "method_getter_addr = "
            "reinterpret_cast[int64_t](<void*>this_method_getter)",
            "",
        ))
        if self.store_base_method_log:
            function_place.base_method_code = code

        unique_module = "Methods_{}_{}_base_method".format(
            id(function_place),
            function_place.name_that_makes_sense.replace('.', '__'))

        mod = compile_module(JIT_FUNCTION_DIR, unique_module, code)
        method_init_fptrs = getattr(mod, cmc.method_init_fptrs)
        method_init_globals = getattr(mod, cmc.method_init_globals)
        init_notifier = getattr(mod, cmc.method_init_recorder_and_notifier)
        method_getter_addr = getattr(mod, 'method_getter_addr')

        call_recorder = mod.NonJITCallRecorder()
        function_place.call_recorder = call_recorder
        # The callback hands this function to the hotspot analysis; when the
        # compiled module invokes it is not visible from here.
        init_notifier(
            call_recorder,
            lambda: self.jit_hotspot_analysis.trigger_jit(function_place))

        g = function_place.globals
        method_init_globals(**{k: g[k] for k in function_place.glob_deps})
        method_init_fptrs(self.jit_fptr_index)

        function_place.base_method_module = mod
        function_place.function_module.f.mut_method_get(method_getter_addr)
        function_place.base_method_addr = mod.base_method_addr

    def add_jit_method(self, function_place: JITFunctionHoldPlace,
                       argtypeids: call_record_t, argtypes: t.Tuple[type, ...]):
        """Compile a method specialised for `argtypes` and register it.

        A typed prologue is generated that rebinds each incoming argument to
        a local of the matching C type, the memoised partial code of the
        function is merged in, and the address of the compiled method is
        stored in ``function_place.methods[argtypeids]``.

        :param function_place: hold place whose `method_arguments`,
            `fn_ptr_name` and `memoize_partial_code` have been filled in.
        :param argtypeids: call record used as the key in the method table.
        :param argtypes: runtime types of the arguments, in positional order.
        """
        method_arguments = function_place.method_arguments
        fn_ptr_name = function_place.fn_ptr_name
        actual_args = [cmc.typed_head + arg for arg in method_arguments]

        once_code_out = CodeOut()
        declaring: list = once_code_out[cmc.Import]
        typing_lines: list = once_code_out[cmc.Customizable]
        typing_lines.append(
            "cdef {}({}):".format(fn_ptr_name, ', '.join(actual_args)))

        undef = object()
        for i, (actual_arg, arg, argtype) in enumerate(
                zip(actual_args, method_arguments, argtypes)):
            path = extension_type_pxd_paths.get(argtype, undef)
            if path is undef:
                # well, this type cannot JIT in fact, so still object and stop recording it.
                typing_lines.append("{}{} = {}".format(
                    cmc.IDENTATION_SECTION, arg, actual_arg))
            elif path is None:
                # it's builtin extension types, like dict, list, etc.
                typing_lines.append("{}cdef {} {} = {}".format(
                    cmc.IDENTATION_SECTION, argtype.__name__, arg, actual_arg))
            else:
                # Good, user-defined extension types!
                # You'll see how fast it'll be!
                # Firstly we import the required type
                typename = "{}{}".format(cmc.typed_head, i)
                declaring.append('from {} cimport {} as {}'.format(
                    path, argtype.__name__, typename))
                typing_lines.append("{}cdef {} {} = {}".format(
                    cmc.IDENTATION_SECTION, typename, arg, actual_arg))

        once_code_out.merge_update(function_place.memoize_partial_code)
        code = '\n'.join(once_code_out.to_code_lines())

        # TODO: use auto-evolution-able method lookup
        code = '\n'.join((
            "",
            code,
            "method_addr = reinterpret_cast[int64_t](<void*>{})".format(
                fn_ptr_name),
            "",
        ))

        # NOTE(review): this name depends only on the function, so every
        # specialisation of the same function reuses it; whether that is safe
        # depends on how `compile_module` treats repeated names -- confirm.
        unique_module = "Methods_{}_{}_JITed".format(
            id(function_place),
            function_place.name_that_makes_sense.replace('.', '__'))
        mod = compile_module(JIT_FUNCTION_DIR, unique_module, code)
        method_init_fptrs = getattr(mod, cmc.method_init_fptrs)
        method_init_globals = getattr(mod, cmc.method_init_globals)

        g = function_place.globals
        method_init_globals(**{k: g[k] for k in function_place.glob_deps})
        method_init_fptrs(self.jit_fptr_index)
        function_place.methods[argtypeids] = mod.method_addr