# coding: utf-8

"""ORM plugin built on SQLAlchemy.
"""

import collections
import importlib
import re
import threading
import types

import ujson
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query as ORMQuery
from sqlalchemy.pool import Pool, QueuePool, NullPool
from sqlalchemy.ext.automap import automap_base

from girlfriend.exception import InvalidStatusException
from girlfriend.util.validating import Rule, be_json
from girlfriend.util.lang import (
    args2fields,
    SequenceCollectionType,
    parse_context_var
)


class EngineManager(object):

    """Keeps every Engine object in one place so that all SQLAlchemy
    plugins share the same engines.
    """

    STATUS_UNINIT = 0  # init_all has not run yet
    STATUS_OK = 1  # engines created and usable
    STATUS_DISPOSED = 2  # dispose_all has run

    # Validation rules for data-source sections (section names start with db_)
    config_rules = (
        # Connection URL, e.g. postgresql://scott:tiger@localhost/test
        Rule("connect_url", required=True,
             type=types.StringTypes, regex=r"^\w+://"),

        # Client encoding passed to create_engine
        Rule("encoding", required=False,
             type=types.StringTypes, default="utf-8"),

        # Extra connect arguments as a JSON string, e.g.
        # {'ssl': {'cert': xxx, 'key': xxx, 'ca': xxx}}
        Rule("connect_args", required=False, type=types.StringTypes,
             default="{}", logic=be_json("connect_args")),

        # Name of the dbpool_ section (without the prefix) this data
        # source should use; None means SQLAlchemy's default pooling
        Rule("pool_policy", required=False,
             type=types.StringTypes, default=None)
    )

    # Validation rules for pool sections (section names start with dbpool_)
    pool_config_rules = (
        # Pool class: a class object such as sqlalchemy.pool.QueuePool, the
        # shortcuts "queuepool"/"nullpool", a bare sqlalchemy.pool class name,
        # or a dotted import path "package.module.ClassName".
        # NOTE(review): a pool *class* is not an instance of Pool; confirm how
        # Rule checks `type` before relying on class objects here.
        Rule("poolclass", required=True, type=(Pool, types.StringTypes)),

        # Number of pooled connections
        Rule("pool_size", required=False, type=(int, types.StringTypes),
             regex=r"^\d+$", default=10),

        # Connection recycle period in seconds. MySQL in particular needs
        # this, otherwise connections expire and get dropped server-side.
        # Strings are accepted (and int()-converted by the loader) for
        # consistency with pool_size, since config files yield strings.
        Rule("pool_recycle", required=False, type=(int, types.StringTypes),
             regex=r"^\d+$", default=3600),

        # Timeout in seconds when checking a connection out of the pool;
        # among SQLAlchemy's built-in pools only QueuePool supports it
        Rule("pool_timeout", required=False, type=(int, types.StringTypes),
             regex=r"^\d+$", default=30),
    )

    def __init__(self):
        self._engines = {}
        self._status = EngineManager.STATUS_UNINIT
        self._validated = False

    def validate_config(self, config):
        """Validate all db_/dbpool_ sections once, so that individual
        plugins do not have to validate them separately.

        Missing optional items are filled in with their rule defaults.
        """
        if self._validated:
            return
        for section in config.prefix("db_", "dbpool_"):
            if section.startswith("db_"):
                self._validate_config_items(
                    config, section, EngineManager.config_rules)
            elif section.startswith("dbpool_"):
                self._validate_config_items(
                    config, section, EngineManager.pool_config_rules)
        self._validated = True

    def _validate_config_items(self, config, section, rules):
        """Validate one section against `rules`, writing defaults for
        missing optional items back into the config."""
        config_items = config[section]
        for rule in rules:
            item_value = config_items.get(rule.name)
            rule.validate(item_value)
            if item_value is None and not rule.required:
                config[section][rule.name] = rule.default

    def init_all(self, config):
        """Create one engine per db_ section. Runs only once; later calls
        are no-ops."""
        if self._status != EngineManager.STATUS_UNINIT:
            return
        all_pool_config = self._load_db_pool_config(config)
        for section in config.prefix("db_"):
            engine_name = section.split("_", 1)[1]
            engine = self._create_engine(config, section, all_pool_config)
            self._engines[engine_name] = EngineContainer(engine)
        self._status = EngineManager.STATUS_OK

    def _load_db_pool_config(self, config):
        """Normalize every dbpool_ section in place and return them keyed
        by pool name (the section name without the dbpool_ prefix).

        String pool classes are resolved to class objects and the numeric
        items are converted to int.
        """
        all_pool_config = {}
        for section in config.prefix("dbpool_"):
            pool_name = section.split("_", 1)[1]
            pool_config = config[section]
            pool_config["poolclass"] = self._resolve_poolclass(
                pool_config["poolclass"])
            for item_name in ("pool_size", "pool_recycle", "pool_timeout"):
                pool_config[item_name] = int(pool_config[item_name])
            all_pool_config[pool_name] = pool_config
        return all_pool_config

    @staticmethod
    def _resolve_poolclass(poolclass):
        """Turn a configured pool class (object or string) into a class."""
        # types.StringTypes also covers unicode values from the config
        if not isinstance(poolclass, types.StringTypes):
            return poolclass
        lowered = poolclass.lower()
        if lowered == "queuepool":
            return QueuePool
        if lowered == "nullpool":
            return NullPool
        if "." in poolclass:
            module_name, class_name = poolclass.rsplit(".", 1)
        else:
            # Bare names are looked up among SQLAlchemy's own pools
            module_name, class_name = "sqlalchemy.pool", poolclass
        # import_module returns the leaf module; __import__("a.b") would
        # return the top-level package "a" and the getattr would fail
        return getattr(importlib.import_module(module_name), class_name)

    def _create_engine(self, config, section, all_pool_config):
        """Build the engine for one db_ section, applying its pool policy."""
        config_items = config[section]
        pool_policy = config_items["pool_policy"]
        kws = {}
        if pool_policy is not None:
            pool_config = all_pool_config[pool_policy]
            if pool_config["poolclass"] == NullPool:
                # NullPool does not accept pool_size/pool_timeout, so pass
                # only the class itself; previously NullPool was dropped
                # entirely and the default pool was used instead.
                kws = {"poolclass": NullPool}
            else:
                kws = dict(pool_config)
        return create_engine(
            config_items["connect_url"],
            encoding=config_items["encoding"],
            connect_args=ujson.loads(config_items["connect_args"]),
            **kws
        )

    def engine(self, engine_name):
        """Return the EngineContainer registered under `engine_name`.

        :raises InvalidStatusException: if init_all has not completed
        """
        if self._status != EngineManager.STATUS_OK:
            raise InvalidStatusException(u"Engine尚未初始化")
        return self._engines[engine_name]

    def dispose_all(self):
        """Dispose every engine. Only acts when the manager is in the OK
        state."""
        if self._status != EngineManager.STATUS_OK:
            return
        for engine_container in self._engines.values():
            engine_container.engine.dispose()
        self._status = EngineManager.STATUS_DISPOSED


_engine_manager = EngineManager()


class EngineContainer(object):

    """Holds an engine object and wraps/proxies engine-related operations.
    """

    def __init__(self, engine):
        self.engine = engine
        self.sessionmaker = sessionmaker(bind=engine)
        self._base_model_class = None
        self._base_model_class_sem = threading.Lock()

    def session(self):
        """Create a new Session bound to this engine."""
        return self.sessionmaker()

    @property
    def base_model_class(self):
        """Lazily build the automap base by reflecting the database."""
        if self._base_model_class is not None:
            return self._base_model_class

        # Guard against concurrent duplicate initialization
        with self._base_model_class_sem:
            # Another thread may have initialized it while we waited
            if self._base_model_class is not None:
                return self._base_model_class
            self._base_model_class = automap_base()
            self._base_model_class.prepare(self.engine, reflect=True)
            return self._base_model_class


class OrmQueryPlugin(object):

    name = "orm_query"

    @staticmethod
    def config_validator(config):
        """Configuration validator: delegates to the shared EngineManager.
        """
        _engine_manager.validate_config(config)

    def __init__(self):
        self._engines = {}

    def sys_prepare(self, config):
        _engine_manager.init_all(config)

    def execute(self, context, *exec_list):
        """Run every executable in `exec_list` in order and return their
        results as a list."""
        return [exec_(context) for exec_ in exec_list]

    def sys_cleanup(self, config):
        _engine_manager.dispose_all()


# Matches SELECT statements, tolerating leading whitespace (common with
# triple-quoted SQL) and any whitespace/punctuation after the keyword.
_SELECT_STATEMENT_REGEX = re.compile(r"^\s*select\b", re.IGNORECASE)


class Query(object):

    """Describes an ORM query and executes it.
    """

    @args2fields()
    def __init__(self, engine_name, variable_name, query_items,
                 query=None, order_by=None, group_by=None, params=None,
                 row_handler=None, result_wrapper=None):
        """
        :param engine_name name of the engine to use
        :param variable_name the result is written into context under
                             this name
        :param query_items items to query; class objects or strings.
                           Strings are resolved through automap
                           ("table" or "table.column")
        :param query a callback function, a full SELECT statement, or a
                     textual/SQL filter condition
        :param params bind parameters for textual/SQL queries; values may
                      reference context variables
        :param row_handler applied to every row to convert its format
        :param result_wrapper wraps the whole result, e.g. into a table
                              object
        """
        # types.StringTypes so unicode order_by strings are handled too
        if isinstance(order_by, types.StringTypes):
            self._order_by = text(order_by)

    def __call__(self, context):
        engine = _engine_manager.engine(self._engine_name)
        query_items = self._resolve_query_items(engine)
        # Resolve into a local: overwriting self._params would freeze the
        # first context's values into this (reusable) Query object.
        params = self._resolve_params(context)

        session = engine.session()
        try:
            if isinstance(self._query, types.StringTypes) and \
                    _SELECT_STATEMENT_REGEX.search(self._query):
                # A complete SELECT statement
                query = session.query(*query_items).from_statement(
                    text(self._query)).params(**(params or {}))
            elif isinstance(self._query, types.FunctionType):
                result = self._query(session, context, *query_items)
                if not isinstance(result, ORMQuery):
                    # The callback returned results rather than a query
                    return self._store_result(
                        context, self._apply_handlers(result))
                query = result
            else:
                # None, a textual filter condition, or a SQL expression
                query = self._build_query(
                    engine, session, query_items, params)
            return self._store_result(context, self._build_result(query))
        finally:
            session.close()

    def _resolve_query_items(self, engine):
        """Return the query items as a list, automapping string items."""
        items = self._query_items
        if isinstance(items, types.StringTypes):
            return [self.automap(engine, items)]
        if isinstance(items, collections.Iterable):
            return [
                self.automap(engine, item)
                if isinstance(item, types.StringTypes) else item
                for item in items
            ]
        return [items]

    def _resolve_params(self, context):
        """Return params with context variable references resolved, leaving
        self._params untouched. Falsy params are returned unchanged."""
        if not self._params:
            return self._params
        return {key: parse_context_var(context, value)
                for key, value in self._params.items()}

    def _apply_handlers(self, result):
        """Apply row_handler (if any) and then result_wrapper (if any)."""
        if self._row_handler:
            result = tuple(self._row_handler(row) for row in result)
        if self._result_wrapper is not None:
            result = self._result_wrapper(result)
        return result

    def _store_result(self, context, result):
        """Write the result into context and return it."""
        context[self._variable_name] = result
        return result

    def automap(self, engine, query_item_str):
        """Resolve "table" to its automapped class, or "table.column" to
        the class attribute. Returns None when not found."""
        base = engine.base_model_class
        if "." in query_item_str:
            table_name, field_name = query_item_str.split(".", 1)
            clazz = getattr(base.classes, table_name, None)
            return getattr(clazz, field_name, None)
        return getattr(base.classes, query_item_str, None)

    def _build_query(self, engine, session, query_items, params=None):
        """Build an ORM query from the configured filter, order_by,
        group_by and params.

        :param params already-resolved bind parameters; falls back to
                      self._params when None
        """
        if params is None:
            params = self._params
        query = session.query(*query_items)
        if self._query is not None:
            if isinstance(self._query, types.StringTypes):
                query = query.filter(text(self._query))
            else:
                query = query.filter(self._query)
        if self._order_by is not None:
            query = query.order_by(self._order_by)
        group_by = self._group_by
        if group_by is not None:
            # Resolve locally instead of caching into self._group_by
            if isinstance(group_by, types.StringTypes):
                group_by = self.automap(engine, group_by)
            query = query.group_by(group_by)
        if params is not None:
            query = query.params(**params)
        return query

    def _build_result(self, query):
        """Execute `query` and apply row_handler/result_wrapper."""
        if self._row_handler:
            result = tuple(self._row_handler(row) for row in query)
        else:
            result = query.all()
        if self._result_wrapper is not None:
            return self._result_wrapper(result)
        return result


class SQL(object):

    """Describes a plain (non-ORM) SQL statement, or a batch of them.
    """

    @args2fields()
    def __init__(self, engine_name, variable_name, sql,
                 params=None, row_handler=None, result_wrapper=None):
        if params is None:
            self._params = {}

    def __call__(self, context):
        engine_container = _engine_manager.engine(self._engine_name)
        session = engine_container.session()
        if isinstance(self._sql, types.StringTypes) and \
                _SELECT_STATEMENT_REGEX.search(self._sql):
            # Closes the session itself
            return self._execute_select_statement(session, context)

        # Non-query statements: executed and committed, nothing returned
        try:
            if isinstance(self._sql, types.StringTypes):
                session.execute(self._sql, self._params)
            elif isinstance(self._sql, SequenceCollectionType):
                # Batch execution; params, if a sequence, pair up by index
                if isinstance(self._params, SequenceCollectionType):
                    for idx, sql in enumerate(self._sql):
                        session.execute(sql, self._params[idx])
                else:
                    for sql in self._sql:
                        session.execute(sql)
            session.commit()
        finally:
            session.close()

    def _execute_select_statement(self, session, context):
        """Run the SELECT and store the (handled/wrapped) rows in context.
        Always closes the result proxy and the session."""
        try:
            result_proxy = session.execute(self._sql, self._params)
            try:
                if self._row_handler is not None:
                    result = tuple(
                        self._row_handler(row) for row in result_proxy)
                else:
                    result = tuple(tuple(row) for row in result_proxy)
                if self._result_wrapper is not None:
                    result = self._result_wrapper(result)
                context[self._variable_name] = result
                return result
            finally:
                # Close even if a handler/wrapper raises
                result_proxy.close()
        finally:
            session.close()


class KeyExtractWrapper(object):

    """Turns one field of each row into a key and wraps the result as a dict.

    Keyed by the first column: ((1, 2, 3, 4, 5), ) => {1: (1, 2, 3, 4, 5)}
    Keyed by an attribute: (User(id=0, name="SamChi", age=31),)
        => {"SamChi": User(...)}
    """

    def __init__(self, key_index):
        """
        :param key_index the key selector: an index/key for list, tuple and
                         dict rows, or an attribute name for other row types
        """
        self._key_index = key_index

    def __call__(self, query_result):
        if not query_result:
            return {}
        # The first row decides between item access and attribute access
        row_0 = query_result[0]
        if isinstance(row_0, (list, tuple, dict)):
            return {row[self._key_index]: row for row in query_result}
        return {getattr(row, self._key_index): row for row in query_result}